Skip to content

Commit f25c9be

Browse files
authored
Merge pull request #695 from SciML/s/alias_elim2
Simpler alias elimination
2 parents 195d6a5 + 80d02c6 commit f25c9be

File tree

5 files changed

+136
-105
lines changed

5 files changed

+136
-105
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,5 @@ end
418418
function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...)
419419
SteadyStateProblemExpr{true}(sys, args...; kwargs...)
420420
end
421+
422+
isdiffeq(eq) = eq.lhs isa Term && operation(eq.lhs) isa Differential

src/systems/diffeqs/odesystem.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,28 +119,34 @@ function ODESystem(eqs, iv=nothing; kwargs...)
119119
end
120120
iv === nothing && throw(ArgumentError("Please pass in independent variables."))
121121
for eq in eqs
122-
for var in vars(eq.rhs for eq eqs)
123-
if isparameter(var) || isparameter(var.op)
124-
isequal(var, iv) || push!(ps, var)
125-
else
126-
push!(allstates, var)
127-
end
128-
end
129-
if !(eq.lhs isa Symbolic)
130-
push!(algeeq, eq)
131-
else
132-
diffvar = first(var_from_nested_derivative(eq.lhs))
122+
collect_vars!(allstates, ps, eq.lhs, iv)
123+
collect_vars!(allstates, ps, eq.rhs, iv)
124+
if isdiffeq(eq)
125+
diffvar, _ = var_from_nested_derivative(eq.lhs)
133126
isequal(iv, iv_from_nested_derivative(eq.lhs)) || throw(ArgumentError("An ODESystem can only have one independent variable."))
134127
diffvar in diffvars && throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
135128
push!(diffvars, diffvar)
136129
push!(diffeq, eq)
130+
else
131+
push!(algeeq, eq)
137132
end
138133
end
139134
algevars = setdiff(allstates, diffvars)
140135
# the orders here are very important!
141136
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
142137
end
143138

139+
function collect_vars!(states, parameters, expr, iv)
140+
for var in vars(expr)
141+
if isparameter(var) || isparameter(var.op)
142+
isequal(var, iv) || push!(parameters, var)
143+
else
144+
push!(states, var)
145+
end
146+
end
147+
return nothing
148+
end
149+
144150
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
145151
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
146152
_eq_unordered(sys1.states, sys2.states) && _eq_unordered(sys1.ps, sys2.ps)

src/systems/reduction.jl

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ function flatten(sys::ODESystem)
66
else
77
return ODESystem(equations(sys),
88
independent_variable(sys),
9+
states(sys),
10+
parameters(sys),
911
observed=observed(sys))
1012
end
1113
end
@@ -27,59 +29,67 @@ function substitute_aliases(diffeqs, dict)
2729
lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,))
2830
end
2931

30-
isvar(s::Sym) = !isparameter(s)
31-
isvar(s::Term) = isvar(s.op)
32+
# Note that we reduce parameters, too
33+
# i.e. `2param = 3` will be reduced away
34+
isvar(s::Sym) = true
35+
isvar(s::Term) = isvar(operation(s))
3236
isvar(s::Any) = false
3337

34-
function filterexpr(f, s)
35-
vs = []
36-
Rewriters.Prewalk(Rewriters.Chain([@rule((~x::f) => push!(vs, ~x))]))(s)
37-
vs
38-
end
38+
function get_α_x(αx)
39+
if isvar(αx)
40+
return 1, αx
41+
elseif αx isa Term && operation(αx) === (*)
42+
args = arguments(αx)
43+
nums = []
44+
syms = []
45+
for arg in args
46+
isvar(arg) ? push!(syms, arg) : push!(nums, arg)
47+
end
3948

40-
function make_lhs_0(eq)
41-
if eq.lhs isa Number && iszero(eq.lhs)
42-
return eq
49+
if length(syms) == 1
50+
return prod(nums), syms[1]
51+
end
4352
else
44-
0 ~ eq.lhs - eq.rhs
53+
return nothing
4554
end
4655
end
4756

4857
function alias_elimination(sys::ODESystem)
4958
eqs = vcat(equations(sys), observed(sys))
50-
51-
# make all algebraic equations have 0 on LHS
52-
eqs = map(eqs) do eq
53-
if eq.lhs isa Term && eq.lhs.op isa Differential
54-
eq
59+
subs = Pair[]
60+
diff_vars = filter(!isnothing, map(eqs) do eq
61+
if isdiffeq(eq)
62+
eq.lhs.args[1]
63+
else
64+
nothing
65+
end
66+
end) |> Set
67+
68+
# only substitute when the variable is algebraic
69+
del = Int[]
70+
for (i, eq) in enumerate(eqs)
71+
isdiffeq(eq) && continue
72+
res_left = get_α_x(eq.lhs)
73+
if !isnothing(res_left) && !(res_left[2] in diff_vars)
74+
# `α x = rhs` => `x = rhs / α`
75+
α, x = res_left
76+
push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α)
77+
push!(del, i)
5578
else
56-
make_lhs_0(eq)
79+
res_right = get_α_x(eq.rhs)
80+
if !isnothing(res_right) && !(res_right[2] in diff_vars)
81+
# `lhs = β y` => `y = lhs / β`
82+
β, y = res_right
83+
push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs)
84+
push!(del, i)
85+
end
5786
end
5887
end
88+
deleteat!(eqs, del)
5989

60-
newstates = map(eqs) do eq
61-
if eq.lhs isa Term && eq.lhs.op isa Differential
62-
filterexpr(isvar, eq.lhs)
63-
else
64-
[]
65-
end
66-
end |> Iterators.flatten |> collect |> unique
67-
90+
eqs′ = substitute_aliases(eqs, Dict(subs))
91+
alias_vars = first.(subs)
6892

69-
all_vars = map(eqs) do eq
70-
filterexpr(isvar, eq.rhs)
71-
end |> Iterators.flatten |> collect |> unique
72-
73-
alg_idxs = findall(x->!(x.lhs isa Term) && iszero(x.lhs), eqs)
74-
75-
eliminate = setdiff(all_vars, newstates)
76-
77-
outputs = solve_for(eqs[alg_idxs], eliminate)
78-
79-
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]
80-
81-
diffeqs′ = substitute_aliases(diffeqs, Dict(eliminate .=> outputs))
82-
83-
ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=eliminate .~ outputs)
93+
newstates = setdiff(states(sys), alias_vars)
94+
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs))
8495
end
85-

test/odesystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,14 @@ for (prob, atol) in [(prob1, 1e-12), (prob2, 1e-12), (prob3, 1e-12)]
200200
sol = solve(prob, Rodas5())
201201
@test all(x->(sum(x), 1.0, atol=atol), sol.u)
202202
end
203+
204+
@parameters t σ β
205+
@variables x(t) y(t) z(t)
206+
@derivatives D'~t
207+
eqs = [D(x) ~ σ*(y-x),
208+
D(y) ~ x-β*y,
209+
x + z ~ y]
210+
sys = ODESystem(eqs)
211+
@test all(isequal.(states(sys), [x, y, z]))
212+
@test all(isequal.(parameters(sys), [σ, β]))
213+
@test equations(sys) == eqs

test/reduction.jl

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,53 +7,52 @@ using ModelingToolkit, OrdinaryDiffEq, Test
77
test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
88

99
eqs = [D(x) ~ σ*(y-x),
10-
D(y) ~ x*-z)-y,
11-
D(z) ~ a*y - β*z,
12-
0 ~ x - a]
10+
D(y) ~ x*-z)-y + β,
11+
0 ~ sin(z) - x + y,
12+
sin(u) ~ x + y,
13+
2β ~ 2,
14+
x ~ a,
15+
]
1316

14-
lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1)
17+
lorenz1 = ODESystem(eqs,t,[u,x,y,z,a],[σ,ρ,β],name=:lorenz1)
1518

1619
lorenz1_aliased = alias_elimination(lorenz1)
17-
@test length(equations(lorenz1_aliased)) == 3
18-
@test length(states(lorenz1_aliased)) == 3
19-
20-
eqs = [D(x) ~ σ*(y-x),
21-
D(y) ~ x*-z)-y,
22-
D(z) ~ x*y - β*z]
23-
24-
@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[a ~ x],name=:lorenz1)
20+
reduced_eqs = [
21+
D(x) ~ σ * (y - x),
22+
D(y) ~ x*-z)-y + 1,
23+
0 ~ sin(z) - x + y,
24+
sin(u) ~ x + y,
25+
]
26+
test_equal.(equations(lorenz1_aliased), reduced_eqs)
27+
test_equal.(states(lorenz1_aliased), [u, x, y, z])
28+
test_equal.(observed(lorenz1_aliased), [
29+
β ~ 1,
30+
a ~ x,
31+
])
2532

2633
# Multi-System Reduction
2734

28-
eqs1 = [D(x) ~ σ*(y-x) + F,
29-
D(y) ~ x*-z)-u,
30-
D(z) ~ x*y - β*z]
31-
32-
aliases = [u ~ x + y - z]
33-
34-
lorenz1 = ODESystem(eqs1,pins=[F],observed=aliases,name=:lorenz1)
35-
36-
eqs2 = [D(x) ~ F,
37-
D(y) ~ x*-z)-x,
38-
D(z) ~ x*y - β*z]
35+
eqs1 = [
36+
D(x) ~ σ*(y-x) + F,
37+
D(y) ~ x*-z)-u,
38+
D(z) ~ x*y - β*z,
39+
u ~ x + y - z,
40+
]
3941

40-
aliases2 = [u ~ x - y - z]
42+
lorenz1 = ODESystem(eqs1,pins=[F],name=:lorenz1)
4143

42-
lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2)
44+
eqs2 = [
45+
D(x) ~ F,
46+
D(y) ~ x*-z)-x,
47+
D(z) ~ x*y - β*z,
48+
u ~ x - y - z
49+
]
4350

44-
connections = [lorenz1.F ~ lorenz2.u,
45-
lorenz2.F ~ lorenz1.u]
51+
lorenz2 = ODESystem(eqs2,pins=[F],name=:lorenz2)
4652

47-
connected = ODESystem([0 ~ a + lorenz1.x - lorenz2.y],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
48-
49-
# Reduced Unflattened System
50-
#=
51-
52-
connections2 = [lorenz1.F ~ lorenz2.u,
53-
lorenz2.F ~ lorenz1.u,
54-
a ~ -lorenz1.x + lorenz2.y]
55-
connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,lorenz2])
56-
=#
53+
connected = ODESystem([lorenz2.y ~ a + lorenz1.x,
54+
lorenz1.F ~ lorenz2.u,
55+
lorenz2.F ~ lorenz1.u],t,[a],[],systems=[lorenz1,lorenz2])
5756

5857
# Reduced Flattened System
5958

@@ -62,6 +61,7 @@ flattened_system = ModelingToolkit.flatten(connected)
6261
aliased_flattened_system = alias_elimination(flattened_system)
6362

6463
@test isequal(states(aliased_flattened_system), [
64+
a
6565
lorenz1.x
6666
lorenz1.y
6767
lorenz1.z
@@ -80,22 +80,24 @@ aliased_flattened_system = alias_elimination(flattened_system)
8080
lorenz2.β
8181
]) |> isempty
8282

83-
test_equal.(equations(aliased_flattened_system), [
84-
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
85-
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
86-
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
87-
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
88-
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
89-
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z])
90-
91-
test_equal.(observed(aliased_flattened_system), [
92-
lorenz1.F ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
93-
lorenz1.u ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z,
94-
lorenz2.F ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z,
95-
a ~ lorenz2.y + -1 * lorenz1.x,
96-
lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
97-
])
98-
83+
reduced_eqs = [
84+
lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination
85+
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
86+
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
87+
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
88+
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
89+
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
90+
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z
91+
]
92+
test_equal.(equations(aliased_flattened_system), reduced_eqs)
93+
94+
observed_eqs = [
95+
lorenz1.F ~ lorenz2.u,
96+
lorenz2.F ~ lorenz1.u,
97+
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z,
98+
lorenz2.u ~ lorenz2.x - lorenz2.y - lorenz2.z,
99+
]
100+
test_equal.(observed(aliased_flattened_system), observed_eqs)
99101

100102
# issue #578
101103

0 commit comments

Comments
 (0)