56
56
57
57
function alias_elimination (sys:: ODESystem )
58
58
eqs = vcat (equations (sys), observed (sys))
59
+ neweqs = Equation[]; sizehint! (neweqs, length (eqs))
59
60
subs = Pair[]
60
61
diff_vars = filter (! isnothing, map (eqs) do eq
61
62
if isdiffeq (eq)
@@ -65,33 +66,56 @@ function alias_elimination(sys::ODESystem)
65
66
end
66
67
end ) |> Set
67
68
68
- # only substitute when the variable is algebraic
69
- del = Int[]
69
+ deps = Set ()
70
70
for (i, eq) in enumerate (eqs)
71
- isdiffeq (eq) && continue
71
+ # only substitute when the variable is algebraic
72
+ if isdiffeq (eq)
73
+ push! (neweqs, eq)
74
+ continue
75
+ end
76
+
77
+ maybe_alias = isalias = false
72
78
res_left = get_α_x (eq. lhs)
73
79
if ! isnothing (res_left) && ! (res_left[2 ] in diff_vars)
74
80
# `α x = rhs` => `x = rhs / α`
75
81
α, x = res_left
76
- push! (subs, x => _isone (α) ? eq. rhs : eq. rhs / α)
77
- push! (del, i)
82
+ sub = x => _isone (α) ? eq. rhs : eq. rhs / α
83
+ maybe_alias = true
78
84
else
79
85
res_right = get_α_x (eq. rhs)
80
86
if ! isnothing (res_right) && ! (res_right[2 ] in diff_vars)
81
87
# `lhs = β y` => `y = lhs / β`
82
88
β, y = res_right
83
- push! (subs, y => _isone (β) ? eq. lhs : β * eq. lhs)
84
- push! (del, i)
89
+ sub = y => _isone (β) ? eq. lhs : β * eq. lhs
90
+ maybe_alias = true
91
+ end
92
+ end
93
+
94
+ if maybe_alias
95
+ l, r = sub
96
+ # alias equations shouldn't introduce cycles
97
+ if ! (l in deps) && isempty (intersect (deps, vars (r)))
98
+ push! (deps, l)
99
+ push! (subs, sub)
100
+ isalias = true
85
101
end
86
102
end
103
+
104
+ if ! isalias
105
+ neweq = _iszero (eq. lhs) ? eq : 0 ~ eq. rhs - eq. lhs
106
+ push! (neweqs, neweq)
107
+ end
87
108
end
88
- deleteat! (eqs, del)
89
109
90
- eqs′ = substitute_aliases (eqs, Dict (subs))
110
+ eqs′ = substitute_aliases (neweqs, Dict (subs))
111
+
91
112
alias_vars = first .(subs)
113
+ sys_states = states (sys)
114
+ alias_eqs = alias_vars .~ last .(subs)
115
+ # alias_eqs = topsort_equations(alias_eqs, sys_states)
92
116
93
- newstates = setdiff (states (sys) , alias_vars)
94
- ODESystem (eqs′, sys. iv, newstates, parameters (sys), observed= alias_vars .~ last .(subs) )
117
+ newstates = setdiff (sys_states , alias_vars)
118
+ ODESystem (eqs′, sys. iv, newstates, parameters (sys), observed= alias_eqs )
95
119
end
96
120
97
121
"""
0 commit comments