Skip to content

Commit 23ab27a

Browse files
authored
Merge pull request #750 from SciML/myb/topsort
Add topsort and improved alias elimination
2 parents deb821c + a169988 commit 23ab27a

File tree

4 files changed

+155
-22
lines changed

4 files changed

+155
-22
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
88
using DiffEqJump
9-
using DataStructures: OrderedDict, OrderedSet
9+
using DataStructures
1010
using SpecialFunctions, NaNMath
1111
using RuntimeGeneratedFunctions
1212
using Base.Threads

src/equations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct Equation
1414
end
1515
Base.:(==)(a::Equation, b::Equation) = all(isequal.((a.lhs, a.rhs), (b.lhs, b.rhs)))
1616
Base.hash(a::Equation, salt::UInt) = hash(a.lhs, hash(a.rhs, salt))
17+
Base.show(io::IO, eq::Equation) = print(io, eq.lhs, " ~ ", eq.rhs)
1718

1819
SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...)
1920

src/systems/reduction.jl

Lines changed: 114 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ end
5656

5757
function alias_elimination(sys::ODESystem)
5858
eqs = vcat(equations(sys), observed(sys))
59+
neweqs = Equation[]; sizehint!(neweqs, length(eqs))
5960
subs = Pair[]
6061
diff_vars = filter(!isnothing, map(eqs) do eq
6162
if isdiffeq(eq)
@@ -65,31 +66,133 @@ function alias_elimination(sys::ODESystem)
6566
end
6667
end) |> Set
6768

68-
# only substitute when the variable is algebraic
69-
del = Int[]
69+
deps = Set()
7070
for (i, eq) in enumerate(eqs)
71-
isdiffeq(eq) && continue
71+
# only substitute when the variable is algebraic
72+
if isdiffeq(eq)
73+
push!(neweqs, eq)
74+
continue
75+
end
76+
77+
maybe_alias = isalias = false
7278
res_left = get_α_x(eq.lhs)
7379
if !isnothing(res_left) && !(res_left[2] in diff_vars)
7480
# `α x = rhs` => `x = rhs / α`
7581
α, x = res_left
76-
push!(subs, x => _isone(α) ? eq.rhs : eq.rhs / α)
77-
push!(del, i)
82+
sub = x => _isone(α) ? eq.rhs : eq.rhs / α
83+
maybe_alias = true
7884
else
7985
res_right = get_α_x(eq.rhs)
8086
if !isnothing(res_right) && !(res_right[2] in diff_vars)
8187
# `lhs = β y` => `y = lhs / β`
8288
β, y = res_right
83-
push!(subs, y => _isone(β) ? eq.lhs : β * eq.lhs)
84-
push!(del, i)
89+
sub = y => _isone(β) ? eq.lhs : β * eq.lhs
90+
maybe_alias = true
91+
end
92+
end
93+
94+
if maybe_alias
95+
l, r = sub
96+
# alias equations shouldn't introduce cycles
97+
if !(l in deps) && isempty(intersect(deps, vars(r)))
98+
push!(deps, l)
99+
push!(subs, sub)
100+
isalias = true
85101
end
86102
end
103+
104+
if !isalias
105+
neweq = _iszero(eq.lhs) ? eq : 0 ~ eq.rhs - eq.lhs
106+
push!(neweqs, neweq)
107+
end
87108
end
88-
deleteat!(eqs, del)
89109

90-
eqs′ = substitute_aliases(eqs, Dict(subs))
110+
eqs′ = substitute_aliases(neweqs, Dict(subs))
111+
91112
alias_vars = first.(subs)
113+
sys_states = states(sys)
114+
alias_eqs = alias_vars .~ last.(subs)
115+
#alias_eqs = topsort_equations(alias_eqs, sys_states)
116+
117+
newstates = setdiff(sys_states, alias_vars)
118+
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_eqs)
119+
end
120+
121+
"""
122+
$(SIGNATURES)
123+
124+
Use Kahn's algorithm to topologically sort observed equations.
125+
126+
Example:
127+
```julia
128+
julia> @variables t x(t) y(t) z(t) k(t)
129+
(t, x(t), y(t), z(t), k(t))
130+
131+
julia> eqs = [
132+
x ~ y + z
133+
z ~ 2
134+
y ~ 2z + k
135+
];
136+
137+
julia> ModelingToolkit.topsort_equations(eqs, [x, y, z, k])
138+
3-element Vector{Equation}:
139+
Equation(z(t), 2)
140+
Equation(y(t), k(t) + 2z(t))
141+
Equation(x(t), y(t) + z(t))
142+
```
143+
"""
144+
function topsort_equations(eqs, states; check=true)
145+
graph, assigns = observed2graph(eqs, states)
146+
neqs = length(eqs)
147+
degrees = zeros(Int, neqs)
148+
149+
for 𝑠eq in 1:length(eqs); var = assigns[𝑠eq]
150+
for 𝑑eq in 𝑑neighbors(graph, var)
151+
# 𝑠eq => 𝑑eq
152+
degrees[𝑑eq] += 1
153+
end
154+
end
155+
156+
q = Queue{Int}(neqs)
157+
for (i, d) in enumerate(degrees)
158+
d == 0 && enqueue!(q, i)
159+
end
160+
161+
idx = 0
162+
ordered_eqs = similar(eqs, 0); sizehint!(ordered_eqs, neqs)
163+
while !isempty(q)
164+
𝑠eq = dequeue!(q)
165+
idx+=1
166+
push!(ordered_eqs, eqs[𝑠eq])
167+
var = assigns[𝑠eq]
168+
for 𝑑eq in 𝑑neighbors(graph, var)
169+
degree = degrees[𝑑eq] = degrees[𝑑eq] - 1
170+
degree == 0 && enqueue!(q, 𝑑eq)
171+
end
172+
end
173+
174+
(check && idx != neqs) && throw(ArgumentError("The equations have at least one cycle."))
175+
176+
return ordered_eqs
177+
end
178+
179+
function observed2graph(eqs, states)
180+
graph = BipartiteGraph(length(eqs), length(states))
181+
v2j = Dict(states .=> 1:length(states))
182+
183+
# `assigns: eq -> var`, `eq` defines `var`
184+
assigns = similar(eqs, Int)
185+
186+
for (i, eq) in enumerate(eqs)
187+
lhs_j = get(v2j, eq.lhs, nothing)
188+
lhs_j === nothing && throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in states."))
189+
assigns[i] = lhs_j
190+
vs = vars(eq.rhs)
191+
for v in vs
192+
j = get(v2j, v, nothing)
193+
j !== nothing && add_edge!(graph, i, j)
194+
end
195+
end
92196

93-
newstates = setdiff(states(sys), alias_vars)
94-
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs))
197+
return graph, assigns
95198
end

test/reduction.jl

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,27 @@
11
using ModelingToolkit, OrdinaryDiffEq, Test
2+
using ModelingToolkit: topsort_equations
3+
4+
@variables t x(t) y(t) z(t) k(t)
5+
eqs = [
6+
x ~ y + z
7+
z ~ 2
8+
y ~ 2z + k
9+
]
10+
11+
sorted_eq = topsort_equations(eqs, [x, y, z, k])
12+
13+
ref_eq = [
14+
z ~ 2
15+
y ~ 2z + k
16+
x ~ y + z
17+
]
18+
@test ref_eq == sorted_eq
19+
20+
@test_throws ArgumentError topsort_equations([
21+
x ~ y + z
22+
z ~ 2
23+
y ~ 2z + x
24+
], [x, y, z, k])
225

326
@parameters t σ ρ β
427
@variables x(t) y(t) z(t) a(t) u(t) F(t)
@@ -21,7 +44,7 @@ reduced_eqs = [
2144
D(x) ~ σ * (y - x),
2245
D(y) ~ x*-z)-y + 1,
2346
0 ~ sin(z) - x + y,
24-
sin(u) ~ x + y,
47+
0 ~ x + y - sin(u),
2548
]
2649
test_equal.(equations(lorenz1_aliased), reduced_eqs)
2750
test_equal.(states(lorenz1_aliased), [u, x, y, z])
@@ -81,7 +104,7 @@ aliased_flattened_system = alias_elimination(flattened_system)
81104
]) |> isempty
82105

83106
reduced_eqs = [
84-
lorenz2.y ~ a + lorenz1.x, # irreducible by alias elimination
107+
0 ~ a + lorenz1.x - lorenz2.y, # irreducible by alias elimination
85108
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
86109
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
87110
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
@@ -115,22 +138,28 @@ let
115138
test_equal.(asys.observed, [y ~ x])
116139
end
117140

118-
# issue #716
141+
# issue #724 and #716
119142
let
120143
@parameters t
121144
D = Differential(t)
122145
@variables x(t), u(t), y(t)
123146
@parameters a, b, c, d
124-
ol = ODESystem([D(x) ~ a * x + b * u, y ~ c * x], t, name=:ol)
147+
ol = ODESystem([D(x) ~ a * x + b * u; y ~ c * x + d * u], t, pins=[u], name=:ol)
125148
@variables u_c(t), y_c(t)
126149
@parameters k_P
127-
pc = ODESystem(Equation[], t, pins=[y_c], observed = [u_c ~ k_P * y_c], name=:pc)
150+
pc = ODESystem(Equation[u_c ~ k_P * y_c], t, pins=[y_c], name=:pc)
128151
connections = [
129-
ol.u ~ pc.u_c,
130-
y_c ~ ol.y
131-
]
152+
ol.u ~ pc.u_c,
153+
pc.y_c ~ ol.y
154+
]
132155
connected = ODESystem(connections, t, systems=[ol, pc])
133-
134156
@test equations(connected) isa Vector{Equation}
135-
@test_nowarn flatten(connected)
157+
sys = flatten(connected)
158+
reduced_sys = alias_elimination(sys)
159+
ref_eqs = [
160+
D(ol.x) ~ ol.a*ol.x + ol.b*pc.u_c
161+
0 ~ ol.c*ol.x + ol.d*pc.u_c - ol.y
162+
0 ~ pc.k_P*ol.y - pc.u_c
163+
]
164+
@test ref_eqs == equations(reduced_sys)
136165
end

0 commit comments

Comments
 (0)