Skip to content

Commit 6671ae6

Browse files
committed
Fix alias elimination
1 parent d4b8d25 commit 6671ae6

File tree

3 files changed

+58
-95
lines changed

3 files changed

+58
-95
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,3 @@ Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
149149
function rename(sys::ODESystem,name)
150150
ODESystem(sys.eqs, sys.iv, sys.states, sys.ps, sys.pins, sys.observed, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
151151
end
152-
153-
isdiffeq(eq) = eq.lhs isa Term && operation(eq.lhs) isa Differential

src/systems/reduction.jl

Lines changed: 18 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,22 @@ function substitute_aliases(diffeqs, dict)
2929
lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,))
3030
end
3131

32-
isvar(s::Sym; param=false) = param ? true : !isparameter(s)
33-
isvar(s::Term; param=false) = isvar(s.op; param=param)
34-
isvar(s::Any;param=false) = false
32+
# Note that we reduce parameters, too
33+
# i.e. `2param = 3` will be reduced away
34+
isvar(s::Sym) = true
35+
isvar(s::Term) = isvar(operation(s))
36+
isvar(s::Any) = false
3537

3638
function get_α_x(αx)
37-
if isvar(αx, param=true)
39+
if isvar(αx)
3840
return 1, αx
3941
elseif αx isa Term && operation(αx) === (*)
4042
args = arguments(αx)
41-
nums = filter(!isvar, args)
42-
syms = filter(isvar, args)
43+
nums = []
44+
syms = []
45+
for arg in args
46+
isvar(arg) ? push!(syms, arg) : push!(nums, arg)
47+
end
4348

4449
if length(syms) == 1
4550
return prod(nums), syms[1]
@@ -51,20 +56,7 @@ end
5156

5257
function alias_elimination(sys::ODESystem)
5358
eqs = vcat(equations(sys), observed(sys))
54-
5559
subs = Pair[]
56-
# Case 1: Right hand side is a constant
57-
ii = findall(eqs) do eq
58-
!(eq.rhs isa Symbolic)
59-
end
60-
for eq in eqs[ii]
61-
α,x = get_α_x(eq.lhs)
62-
push!(subs, x => isone(α) ? eq.rhs : eq.rhs / α)
63-
end
64-
deleteat!(eqs, ii) # remove them
65-
66-
# Case 2: One side is a differentiated var, the other is an algebraic var
67-
# substitute the algebraic var with the diff var
6860
diff_vars = filter(!isnothing, map(eqs) do eq
6961
if isdiffeq(eq)
7062
eq.lhs.args[1]
@@ -73,41 +65,19 @@ function alias_elimination(sys::ODESystem)
7365
end
7466
end) |> Set
7567

76-
del = Int[]
77-
for (i, eq) in enumerate(eqs)
78-
res_left = get_α_x(eq.lhs)
79-
if !isnothing(res_left)
80-
α, x = res_left
81-
res_right = get_α_x(eq.rhs)
82-
if !isnothing(res_right)
83-
β, y = res_right
84-
if y in diff_vars && !(x in diff_vars)
85-
multiple = β / α
86-
push!(subs, x => isone(multiple) ? y : multiple * y)
87-
push!(del, i)
88-
elseif x in diff_vars && !(y in diff_vars)
89-
multiple = α / β
90-
push!(subs, y => isone(multiple) ? x : multiple * x)
91-
push!(del, i)
92-
end
93-
end
94-
end
95-
end
96-
deleteat!(eqs, del)
97-
98-
# Case 3: Explicit substitutions
68+
# only substitute when the variable is algebraic
9969
del = Int[]
10070
for (i, eq) in enumerate(eqs)
10171
isdiffeq(eq) && continue
10272
res_left = get_α_x(eq.lhs)
103-
if !isnothing(res_left)
73+
if !isnothing(res_left) && !(res_left[2] in diff_vars)
10474
# `α x = rhs` => `x = rhs / α`
10575
α, x = res_left
10676
push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α)
10777
push!(del, i)
10878
else
10979
res_right = get_α_x(eq.rhs)
110-
if !isnothing(res_right)
80+
if !isnothing(res_right) && !(res_right[2] in diff_vars)
11181
# `lhs = β y` => `y = lhs / β`
11282
β, y = res_right
11383
push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs)
@@ -117,11 +87,9 @@ function alias_elimination(sys::ODESystem)
11787
end
11888
deleteat!(eqs, del)
11989

120-
diffeqs = filter(eq -> eq.lhs isa Term && eq.lhs.op isa Differential, eqs)
121-
diffeqs′ = substitute_aliases(diffeqs, Dict(subs))
90+
eqs′ = substitute_aliases(eqs, Dict(subs))
91+
alias_vars = first.(subs)
12292

123-
newstates = map(diffeqs) do eq
124-
eq.lhs.args[1]
125-
end
126-
ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=first.(subs) .~ last.(subs))
93+
newstates = setdiff(states(sys), alias_vars)
94+
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs))
12795
end

test/reduction.jl

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,55 +7,52 @@ using ModelingToolkit, OrdinaryDiffEq, Test
77
test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
88

99
eqs = [D(x) ~ σ*(y-x),
10-
D(y) ~ x*-z)-y,
11-
D(z) ~ a*y - β*z,
12-
β ~ 2,
13-
x ~ a]
10+
D(y) ~ x*-z)-y + β,
11+
0 ~ sin(z) - x + y,
12+
sin(u) ~ x + y,
13+
2β ~ 2,
14+
x ~ a,
15+
]
1416

15-
lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1)
17+
lorenz1 = ODESystem(eqs,t,[u,x,y,z,a],[σ,ρ,β],name=:lorenz1)
1618

1719
lorenz1_aliased = alias_elimination(lorenz1)
18-
@test length(equations(lorenz1_aliased)) == 3
19-
@test length(states(lorenz1_aliased)) == 3
20-
21-
eqs = [D(x) ~ σ*(y-x),
22-
D(y) ~ x*-z)-y,
23-
D(z) ~ x*y - 2*z]
24-
25-
# TODO: maybe remove β from ps, or maybe don't allow this example on params
26-
@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=~ 2, a ~ x],name=:lorenz1)
20+
reduced_eqs = [
21+
D(x) ~ σ * (y - x),
22+
D(y) ~ x*-z)-y + 1,
23+
0 ~ sin(z) - x + y,
24+
sin(u) ~ x + y,
25+
]
26+
test_equal.(equations(lorenz1_aliased), reduced_eqs)
27+
test_equal.(states(lorenz1_aliased), [u, x, y, z])
28+
test_equal.(observed(lorenz1_aliased), [
29+
β ~ 1,
30+
a ~ x,
31+
])
2732

2833
# Multi-System Reduction
2934

30-
eqs1 = [D(x) ~ σ*(y-x) + F,
31-
D(y) ~ x*-z)-u,
32-
D(z) ~ x*y - β*z]
33-
34-
aliases = [u ~ x + y - z]
35+
eqs1 = [
36+
D(x) ~ σ*(y-x) + F,
37+
D(y) ~ x*-z)-u,
38+
D(z) ~ x*y - β*z,
39+
u ~ x + y - z,
40+
]
3541

36-
lorenz1 = ODESystem(eqs1,pins=[F],observed=aliases,name=:lorenz1)
42+
lorenz1 = ODESystem(eqs1,pins=[F],name=:lorenz1)
3743

38-
eqs2 = [D(x) ~ F,
39-
D(y) ~ x*-z)-x,
40-
D(z) ~ x*y - β*z]
44+
eqs2 = [
45+
D(x) ~ F,
46+
D(y) ~ x*-z)-x,
47+
D(z) ~ x*y - β*z,
48+
u ~ x - y - z
49+
]
4150

42-
aliases2 = [u ~ x - y - z]
51+
lorenz2 = ODESystem(eqs2,pins=[F],name=:lorenz2)
4352

44-
lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2)
45-
46-
connections = [lorenz1.F ~ lorenz2.u,
47-
lorenz2.F ~ lorenz1.u]
48-
49-
connected = ODESystem([lorenz2.y ~ a + lorenz1.x],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
50-
51-
# Reduced Unflattened System
52-
#=
53-
54-
connections2 = [lorenz1.F ~ lorenz2.u,
55-
lorenz2.F ~ lorenz1.u,
56-
a ~ -lorenz1.x + lorenz2.y]
57-
connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,lorenz2])
58-
=#
53+
connected = ODESystem([lorenz2.y ~ a + lorenz1.x,
54+
lorenz1.F ~ lorenz2.u,
55+
lorenz2.F ~ lorenz1.u],t,[a],[],systems=[lorenz1,lorenz2])
5956

6057
# Reduced Flattened System
6158

@@ -64,6 +61,7 @@ flattened_system = ModelingToolkit.flatten(connected)
6461
aliased_flattened_system = alias_elimination(flattened_system)
6562

6663
@test isequal(states(aliased_flattened_system), [
64+
a
6765
lorenz1.x
6866
lorenz1.y
6967
lorenz1.z
@@ -83,25 +81,24 @@ aliased_flattened_system = alias_elimination(flattened_system)
8381
]) |> isempty
8482

8583
reduced_eqs = [
86-
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - (a + lorenz1.x) - lorenz2.z,
84+
lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination
85+
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
8786
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
8887
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
8988
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
9089
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
91-
D(lorenz2.z) ~ lorenz2.x*(a + lorenz1.x) - lorenz2.β*lorenz2.z
90+
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z
9291
]
9392
test_equal.(equations(aliased_flattened_system), reduced_eqs)
9493

9594
observed_eqs = [
96-
lorenz2.y ~ a + lorenz1.x,
9795
lorenz1.F ~ lorenz2.u,
9896
lorenz2.F ~ lorenz1.u,
9997
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z,
10098
lorenz2.u ~ lorenz2.x - lorenz2.y - lorenz2.z,
10199
]
102100
test_equal.(observed(aliased_flattened_system), observed_eqs)
103101

104-
105102
# issue #578
106103

107104
let

0 commit comments

Comments
 (0)