Skip to content

Commit 572c24d

Browse files
committed
Add Kahn's algorithm to topologically sort observed equations
1 parent deb821c commit 572c24d

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack
88
using DiffEqJump
9-
using DataStructures: OrderedDict, OrderedSet
9+
using DataStructures
1010
using SpecialFunctions, NaNMath
1111
using RuntimeGeneratedFunctions
1212
using Base.Threads

src/systems/reduction.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,83 @@ function alias_elimination(sys::ODESystem)
9393
newstates = setdiff(states(sys), alias_vars)
9494
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_vars .~ last.(subs))
9595
end
96+
97+
"""
98+
$(SIGNATURES)
99+
100+
Use Kahn's algorithm to topologically sort observed equations.
101+
102+
Example:
103+
```julia
104+
julia> @variables t x(t) y(t) z(t) k(t)
105+
(t, x(t), y(t), z(t), k(t))
106+
107+
julia> eqs = [
108+
x ~ y + z
109+
z ~ 2
110+
y ~ 2z + k
111+
];
112+
113+
julia> ModelingToolkit.topsort_observed(eqs, [x, y, z, k])
114+
3-element Vector{Equation}:
115+
Equation(z(t), 2)
116+
Equation(y(t), k(t) + 2z(t))
117+
Equation(x(t), y(t) + z(t))
118+
```
119+
"""
120+
function topsort_observed(eqs, states)
121+
graph, assigns, v2j = observed2graph(eqs, states)
122+
neqs = length(eqs)
123+
degrees = zeros(Int, neqs)
124+
125+
for 𝑠eq in 1:length(eqs); var = assigns[𝑠eq]
126+
for 𝑑eq in 𝑑neighbors(graph, var)
127+
# 𝑠eq => 𝑑eq
128+
degrees[𝑑eq] += 1
129+
end
130+
end
131+
132+
q = Queue{Int}(neqs)
133+
for (i, d) in enumerate(degrees)
134+
d == 0 && enqueue!(q, i)
135+
end
136+
137+
idx = 0
138+
order = zeros(Int, neqs)
139+
while !isempty(q)
140+
j = dequeue!(q)
141+
order[idx+=1] = j
142+
for 𝑠eq in 1:length(eqs); var = assigns[𝑠eq]
143+
for 𝑑eq in 𝑑neighbors(graph, var)
144+
# 𝑠eq => 𝑑eq
145+
degree = degrees[𝑑eq] = degrees[𝑑eq] - 1
146+
degree == 0 && enqueue!(q, 𝑑eq)
147+
end
148+
end
149+
end
150+
151+
idx == neqs || throw(ArgumentError("There's a cycle in obversed equations."))
152+
153+
return eqs[order]
154+
end
155+
156+
function observed2graph(eqs, states)
157+
graph = BipartiteGraph(length(eqs), length(states))
158+
v2j = Dict(states .=> 1:length(states))
159+
160+
# `eqs[eq_idx]` defines `assigns[eq_idx]` var
161+
assigns = Vector{Any}(undef, length(eqs))
162+
163+
for (i, eq) in enumerate(eqs)
164+
lhs_j = get(v2j, eq.lhs, nothing)
165+
lhs_j === nothing && throw(ArgumentError("The lhs $lhs of $eq, doesn't appear in states."))
166+
assigns[i] = lhs_j
167+
vs = vars(eq.rhs)
168+
for v in vs
169+
j = get(v2j, v, nothing)
170+
j !== nothing && add_edge!(graph, i, j)
171+
end
172+
end
173+
174+
return graph, assigns, v2j
175+
end

0 commit comments

Comments
 (0)