Skip to content

Commit 37a668e

Browse files
committed
Make alias elimination work with NonlinearSystem
1 parent 3cba8b7 commit 37a668e

File tree

3 files changed

+54
-11
lines changed

3 files changed

+54
-11
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ function NonlinearSystem(eqs, states, ps;
5959
NonlinearSystem(eqs, value.(states), value.(ps), value.(pins), observed, name, systems, default_u0, default_p)
6060
end
6161

62+
independent_variable(::NonlinearSystem) = nothing
63+
6264
function calculate_jacobian(sys::NonlinearSystem;sparse=false,simplify=false)
6365
rhs = [eq.rhs for eq equations(sys)]
6466
vals = [dv for dv in states(sys)]
@@ -77,9 +79,19 @@ function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = paramete
7779
conv = AbstractSysToExpr(sys), kwargs...)
7880
end
7981

80-
function generate_function(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys); kwargs...)
81-
rhss = [eq.rhs for eq sys.eqs]
82-
return build_function(rhss, vs, ps;
82+
function generate_function(sys::NonlinearSystem, dvs = states(sys), ps = parameters(sys); kwargs...)
83+
obsvars = map(eq->eq.lhs, observed(sys))
84+
fulldvs = [dvs; obsvars]
85+
fulldvs′ = makesym.(value.(fulldvs))
86+
87+
sub = Dict(fulldvs .=> fulldvs′)
88+
# substitute x(t) by just x
89+
rhss = [substitute(deq.rhs, sub) for deq equations(sys)]
90+
obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq observed(sys)]
91+
92+
dvs′ = fulldvs′[1:length(dvs)]
93+
ps′ = makesym.(value.(ps), states=())
94+
return build_function(Let(obss, rhss), dvs′, ps′;
8395
conv = AbstractSysToExpr(sys), kwargs...)
8496
end
8597

src/systems/reduction.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ function maybe_alias(lhs, rhs, diff_vars, iv, conservative)
8484
end
8585
end
8686

87-
function alias_elimination(sys::ODESystem; conservative=true)
87+
function alias_elimination(sys; conservative=true)
8888
iv = independent_variable(sys)
89-
eqs = vcat(equations(sys), observed(sys))
89+
eqs = equations(sys)
9090
diff_vars = filter(!isnothing, map(eqs) do eq
9191
if isdiffeq(eq)
9292
arguments(eq.lhs)[1]
@@ -131,11 +131,15 @@ function alias_elimination(sys::ODESystem; conservative=true)
131131
end
132132

133133
alias_vars = first.(subs)
134-
sys_states = states(sys)
135-
alias_eqs = topsort_equations(alias_vars .~ last.(subs), sys_states)
136-
137-
newstates = setdiff(sys_states, alias_vars)
138-
ODESystem(neweqs, sys.iv, newstates, parameters(sys), observed=alias_eqs)
134+
sts = states(sys)
135+
fullsts = vcat(map(eq->eq.lhs, observed(sys)), sts)
136+
alias_eqs = topsort_equations(alias_vars .~ last.(subs), fullsts)
137+
newstates = setdiff(sts, alias_vars)
138+
139+
@set! sys.eqs = neweqs
140+
@set! sys.states = newstates
141+
@set! sys.observed = [observed(sys); alias_eqs]
142+
return sys
139143
end
140144

141145
"""

test/reduction.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, OrdinaryDiffEq, Test
1+
using ModelingToolkit, OrdinaryDiffEq, Test, NonlinearSolve
22
using ModelingToolkit: topsort_equations
33

44
@variables t x(t) y(t) z(t) k(t)
@@ -179,3 +179,30 @@ let
179179
]
180180
@test ref_eqs == equations(reduced_sys)
181181
end
182+
183+
# NonlinearSystem
184+
@parameters t
185+
@variables u1(t) u2(t) u3(t) u4(t) u5(t)
186+
eqs = [
187+
2u1 ~ 3u5
188+
u2 ~ u1
189+
u3 ~ 2u1 - u2
190+
u4 ~ u2 + u3^2
191+
u5 ~ u4^2 - u1
192+
]
193+
sys = NonlinearSystem(eqs, [u1, u2, u3, u4, u5], [])
194+
reducedsys = alias_elimination(sys)
195+
@test observed(reducedsys) == [u1 ~ 3/2 * u5]
196+
197+
u0 = [
198+
u1 => 1
199+
u2 => 1
200+
u3 => 0.3
201+
u4 => 0.6
202+
u5 => 2/3
203+
]
204+
nlprob = NonlinearProblem(reducedsys, u0)
205+
reducedsol = solve(nlprob, NewtonRaphson())
206+
residual = fill(100.0, 4)
207+
nlprob.f(residual, reducedsol.u, nothing)
208+
@test all(x->abs(x) < 1e-5, residual)

0 commit comments

Comments
 (0)