56
56
57
57
function alias_elimination (sys:: ODESystem )
58
58
eqs = vcat (equations (sys), observed (sys))
59
+ neweqs = Equation[]; sizehint! (neweqs, length (eqs))
59
60
subs = Pair[]
60
61
diff_vars = filter (! isnothing, map (eqs) do eq
61
62
if isdiffeq (eq)
@@ -65,31 +66,133 @@ function alias_elimination(sys::ODESystem)
65
66
end
66
67
end ) |> Set
67
68
68
- # only substitute when the variable is algebraic
69
- del = Int[]
69
+ deps = Set ()
70
70
for (i, eq) in enumerate (eqs)
71
- isdiffeq (eq) && continue
71
+ # only substitute when the variable is algebraic
72
+ if isdiffeq (eq)
73
+ push! (neweqs, eq)
74
+ continue
75
+ end
76
+
77
+ maybe_alias = isalias = false
72
78
res_left = get_α_x (eq. lhs)
73
79
if ! isnothing (res_left) && ! (res_left[2 ] in diff_vars)
74
80
# `α x = rhs` => `x = rhs / α`
75
81
α, x = res_left
76
- push! (subs, x => _isone (α) ? eq. rhs : eq. rhs / α)
77
- push! (del, i)
82
+ sub = x => _isone (α) ? eq. rhs : eq. rhs / α
83
+ maybe_alias = true
78
84
else
79
85
res_right = get_α_x (eq. rhs)
80
86
if ! isnothing (res_right) && ! (res_right[2 ] in diff_vars)
81
87
# `lhs = β y` => `y = lhs / β`
82
88
β, y = res_right
83
- push! (subs, y => _isone (β) ? eq. lhs : β * eq. lhs)
84
- push! (del, i)
89
+ sub = y => _isone (β) ? eq. lhs : β * eq. lhs
90
+ maybe_alias = true
91
+ end
92
+ end
93
+
94
+ if maybe_alias
95
+ l, r = sub
96
+ # alias equations shouldn't introduce cycles
97
+ if ! (l in deps) && isempty (intersect (deps, vars (r)))
98
+ push! (deps, l)
99
+ push! (subs, sub)
100
+ isalias = true
85
101
end
86
102
end
103
+
104
+ if ! isalias
105
+ neweq = _iszero (eq. lhs) ? eq : 0 ~ eq. rhs - eq. lhs
106
+ push! (neweqs, neweq)
107
+ end
87
108
end
88
- deleteat! (eqs, del)
89
109
90
- eqs′ = substitute_aliases (eqs, Dict (subs))
110
+ eqs′ = substitute_aliases (neweqs, Dict (subs))
111
+
91
112
alias_vars = first .(subs)
113
+ sys_states = states (sys)
114
+ alias_eqs = alias_vars .~ last .(subs)
115
+ # alias_eqs = topsort_equations(alias_eqs, sys_states)
116
+
117
+ newstates = setdiff (sys_states, alias_vars)
118
+ ODESystem (eqs′, sys. iv, newstates, parameters (sys), observed= alias_eqs)
119
+ end
120
+
121
+ """
122
+ $(SIGNATURES)
123
+
124
+ Use Kahn's algorithm to topologically sort observed equations.
125
+
126
+ Example:
127
+ ```julia
128
+ julia> @variables t x(t) y(t) z(t) k(t)
129
+ (t, x(t), y(t), z(t), k(t))
130
+
131
+ julia> eqs = [
132
+ x ~ y + z
133
+ z ~ 2
134
+ y ~ 2z + k
135
+ ];
136
+
137
+ julia> ModelingToolkit.topsort_equations(eqs, [x, y, z, k])
138
+ 3-element Vector{Equation}:
139
+ Equation(z(t), 2)
140
+ Equation(y(t), k(t) + 2z(t))
141
+ Equation(x(t), y(t) + z(t))
142
+ ```
143
+ """
144
+ function topsort_equations (eqs, states; check= true )
145
+ graph, assigns = observed2graph (eqs, states)
146
+ neqs = length (eqs)
147
+ degrees = zeros (Int, neqs)
148
+
149
+ for 𝑠eq in 1 : length (eqs); var = assigns[𝑠eq]
150
+ for 𝑑eq in 𝑑neighbors (graph, var)
151
+ # 𝑠eq => 𝑑eq
152
+ degrees[𝑑eq] += 1
153
+ end
154
+ end
155
+
156
+ q = Queue {Int} (neqs)
157
+ for (i, d) in enumerate (degrees)
158
+ d == 0 && enqueue! (q, i)
159
+ end
160
+
161
+ idx = 0
162
+ ordered_eqs = similar (eqs, 0 ); sizehint! (ordered_eqs, neqs)
163
+ while ! isempty (q)
164
+ 𝑠eq = dequeue! (q)
165
+ idx+= 1
166
+ push! (ordered_eqs, eqs[𝑠eq])
167
+ var = assigns[𝑠eq]
168
+ for 𝑑eq in 𝑑neighbors (graph, var)
169
+ degree = degrees[𝑑eq] = degrees[𝑑eq] - 1
170
+ degree == 0 && enqueue! (q, 𝑑eq)
171
+ end
172
+ end
173
+
174
+ (check && idx != neqs) && throw (ArgumentError (" The equations have at least one cycle." ))
175
+
176
+ return ordered_eqs
177
+ end
178
+
179
+ function observed2graph (eqs, states)
180
+ graph = BipartiteGraph (length (eqs), length (states))
181
+ v2j = Dict (states .=> 1 : length (states))
182
+
183
+ # `assigns: eq -> var`, `eq` defines `var`
184
+ assigns = similar (eqs, Int)
185
+
186
+ for (i, eq) in enumerate (eqs)
187
+ lhs_j = get (v2j, eq. lhs, nothing )
188
+ lhs_j === nothing && throw (ArgumentError (" The lhs $(eq. lhs) of $eq , doesn't appear in states." ))
189
+ assigns[i] = lhs_j
190
+ vs = vars (eq. rhs)
191
+ for v in vs
192
+ j = get (v2j, v, nothing )
193
+ j != = nothing && add_edge! (graph, i, j)
194
+ end
195
+ end
92
196
93
- newstates = setdiff (states (sys), alias_vars)
94
- ODESystem (eqs′, sys. iv, newstates, parameters (sys), observed= alias_vars .~ last .(subs))
197
+ return graph, assigns
95
198
end
0 commit comments