Skip to content

Commit d67fe45

Browse files
committed
More robust conservative alias elimination, and better tests
1 parent 7b9ec75 commit d67fe45

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

src/systems/reduction.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,24 @@ function get_α_x(αx)
5555
end
5656
end
5757

58-
function is_sub_candidate(rhs, conservative)
58+
function is_univariate_expr(ex, iv)
59+
count = 0
60+
for var in vars(ex)
61+
if !isequal(iv, var) && !isparameter(var)
62+
count += 1
63+
count > 1 && return false
64+
end
65+
end
66+
return count <= 1
67+
end
68+
69+
function is_sub_candidate(ex, iv, conservative)
5970
conservative || return true
60-
isvar(rhs) || rhs isa Number
71+
isvar(ex) || ex isa Number || is_univariate_expr(ex, iv)
6172
end
6273

63-
function maybe_alias(lhs, rhs, diff_vars, conservative)
64-
is_sub_candidate(rhs, conservative) || return false, nothing
74+
function maybe_alias(lhs, rhs, diff_vars, iv, conservative)
75+
is_sub_candidate(rhs, iv, conservative) || return false, nothing
6576

6677
res_left = get_α_x(lhs)
6778
if res_left !== nothing && !(res_left[2] in diff_vars)
@@ -74,6 +85,7 @@ function maybe_alias(lhs, rhs, diff_vars, conservative)
7485
end
7586

7687
function alias_elimination(sys::ODESystem; conservative=true)
88+
iv = independent_variable(sys)
7789
eqs = vcat(equations(sys), observed(sys))
7890
diff_vars = filter(!isnothing, map(eqs) do eq
7991
if isdiffeq(eq)
@@ -95,10 +107,10 @@ function alias_elimination(sys::ODESystem; conservative=true)
95107
end
96108

97109
# `α x = rhs` => `x = rhs / α`
98-
ma, sub = maybe_alias(eq.lhs, eq.rhs, diff_vars, conservative)
110+
ma, sub = maybe_alias(eq.lhs, eq.rhs, diff_vars, iv, conservative)
99111
if !ma
100112
# `lhs = β y` => `y = lhs / β`
101-
ma, sub = maybe_alias(eq.rhs, eq.lhs, diff_vars, conservative)
113+
ma, sub = maybe_alias(eq.rhs, eq.lhs, diff_vars, iv, conservative)
102114
end
103115

104116
isalias = false

test/reduction.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ eqs = [
3434
D(y) ~ x*-z)-y + β
3535
0 ~ sin(z) - x + y
3636
sin(u) ~ x + y
37-
x ~ a
37+
2x ~ 3a
38+
2u ~ 3x
3839
]
3940

4041
lorenz1 = ODESystem(eqs,t,name=:lorenz1)
@@ -49,8 +50,9 @@ reduced_eqs = [
4950
test_equal.(equations(lorenz1_aliased), reduced_eqs)
5051
@test isempty(setdiff(states(lorenz1_aliased), [u, x, y, z]))
5152
test_equal.(observed(lorenz1_aliased), [
52-
a ~ x,
53-
])
53+
a ~ 2/3 * x,
54+
u ~ 3/2 * x,
55+
])
5456

5557
# Multi-System Reduction
5658

0 commit comments

Comments
 (0)