Skip to content

Commit a169988

Browse files
committed
Alias elimination with cycle removal and canonicalization
1 parent b799532 commit a169988

File tree

2 files changed

+51
-21
lines changed

2 files changed

+51
-21
lines changed

src/systems/reduction.jl

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ end
5656

5757
function alias_elimination(sys::ODESystem)
5858
eqs = vcat(equations(sys), observed(sys))
59+
neweqs = Equation[]; sizehint!(neweqs, length(eqs))
5960
subs = Pair[]
6061
diff_vars = filter(!isnothing, map(eqs) do eq
6162
if isdiffeq(eq)
@@ -65,33 +66,56 @@ function alias_elimination(sys::ODESystem)
6566
end
6667
end) |> Set
6768

68-
# only substitute when the variable is algebraic
69-
del = Int[]
69+
deps = Set()
7070
for (i, eq) in enumerate(eqs)
71-
isdiffeq(eq) && continue
71+
# only substitute when the variable is algebraic
72+
if isdiffeq(eq)
73+
push!(neweqs, eq)
74+
continue
75+
end
76+
77+
maybe_alias = isalias = false
7278
res_left = get_α_x(eq.lhs)
7379
if !isnothing(res_left) && !(res_left[2] in diff_vars)
7480
# `α x = rhs` => `x = rhs / α`
7581
α, x = res_left
76-
push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α)
77-
push!(del, i)
82+
sub = x => _isone(α) ? eq.rhs : eq.rhs / α
83+
maybe_alias = true
7884
else
7985
res_right = get_α_x(eq.rhs)
8086
if !isnothing(res_right) && !(res_right[2] in diff_vars)
8187
# `lhs = β y` => `y = lhs / β`
8288
β, y = res_right
83-
push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs)
84-
push!(del, i)
89+
sub = y => _isone(β) ? eq.lhs : β * eq.lhs
90+
maybe_alias = true
91+
end
92+
end
93+
94+
if maybe_alias
95+
l, r = sub
96+
# alias equations shouldn't introduce cycles
97+
if !(l in deps) && isempty(intersect(deps, vars(r)))
98+
push!(deps, l)
99+
push!(subs, sub)
100+
isalias = true
85101
end
86102
end
103+
104+
if !isalias
105+
neweq = _iszero(eq.lhs) ? eq : 0 ~ eq.rhs - eq.lhs
106+
push!(neweqs, neweq)
107+
end
87108
end
88-
deleteat!(eqs, del)
89109

90-
eqs′ = substitute_aliases(eqs, Dict(subs))
110+
eqs′ = substitute_aliases(neweqs, Dict(subs))
111+
91112
alias_vars = first.(subs)
113+
sys_states = states(sys)
114+
alias_eqs = alias_vars .~ last.(subs)
115+
#alias_eqs = topsort_equations(alias_eqs, sys_states)
92116

93-
newstates = setdiff(states(sys), alias_vars)
94-
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs))
117+
newstates = setdiff(sys_states, alias_vars)
118+
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_eqs)
95119
end
96120

97121
"""

test/reduction.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ reduced_eqs = [
4444
D(x) ~ σ * (y - x),
4545
D(y) ~ x*-z)-y + 1,
4646
0 ~ sin(z) - x + y,
47-
sin(u) ~ x + y,
47+
0 ~ x + y - sin(u),
4848
]
4949
test_equal.(equations(lorenz1_aliased), reduced_eqs)
5050
test_equal.(states(lorenz1_aliased), [u, x, y, z])
@@ -104,7 +104,7 @@ aliased_flattened_system = alias_elimination(flattened_system)
104104
]) |> isempty
105105

106106
reduced_eqs = [
107-
lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination
107+
0 ~ a + lorenz1.x - lorenz2.y, # irreducible by alias elimination
108108
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
109109
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
110110
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
@@ -138,22 +138,28 @@ let
138138
test_equal.(asys.observed, [y ~ x])
139139
end
140140

141-
# issue #716
141+
# issue #724 and #716
142142
let
143143
@parameters t
144144
D = Differential(t)
145145
@variables x(t), u(t), y(t)
146146
@parameters a, b, c, d
147-
ol = ODESystem([D(x) ~ a * x + b * u, y ~ c * x], t, name=:ol)
147+
ol = ODESystem([D(x) ~ a * x + b * u; y ~ c * x + d * u], t, pins=[u], name=:ol)
148148
@variables u_c(t), y_c(t)
149149
@parameters k_P
150-
pc = ODESystem(Equation[], t, pins=[y_c], observed = [u_c ~ k_P * y_c], name=:pc)
150+
pc = ODESystem(Equation[u_c ~ k_P * y_c], t, pins=[y_c], name=:pc)
151151
connections = [
152-
ol.u ~ pc.u_c,
153-
y_c ~ ol.y
154-
]
152+
ol.u ~ pc.u_c,
153+
pc.y_c ~ ol.y
154+
]
155155
connected = ODESystem(connections, t, systems=[ol, pc])
156-
157156
@test equations(connected) isa Vector{Equation}
158-
@test_nowarn flatten(connected)
157+
sys = flatten(connected)
158+
reduced_sys = alias_elimination(sys)
159+
ref_eqs = [
160+
D(ol.x) ~ ol.a*ol.x + ol.b*pc.u_c
161+
0 ~ ol.c*ol.x + ol.d*pc.u_c - ol.y
162+
0 ~ pc.k_P*ol.y - pc.u_c
163+
]
164+
@test ref_eqs == equations(reduced_sys)
159165
end

0 commit comments

Comments
 (0)