Skip to content

Commit 5ed420e

Browse files
authored
Merge pull request #751 from SciML/ms/observed_lowering
Convert observed equations as assignments when lowering
2 parents c438d20 + 7cd32de commit 5ed420e

File tree

7 files changed

+275
-217
lines changed

7 files changed

+275
-217
lines changed

src/build_function.jl

Lines changed: 187 additions & 151 deletions
Large diffs are not rendered by default.

src/systems/abstractsystem.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,16 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
159159
throw(error("Variable $name does not exist"))
160160
end
161161

162-
renamespace(namespace,name) = Symbol(namespace,:₊,name)
163-
164-
function renamespace(namespace, x::Sym)
165-
Sym{symtype(x)}(renamespace(namespace,x.name))
166-
end
167-
168-
function renamespace(namespace, x::Term)
169-
renamespace(namespace, operation(x))(arguments(x)...)
162+
function renamespace(namespace, x)
163+
if x isa Num
164+
renamespace(namespace, value(x))
165+
elseif istree(x)
166+
renamespace(namespace, operation(x))(arguments(x)...)
167+
elseif x isa Sym
168+
Sym{symtype(x)}(renamespace(namespace,nameof(x)))
169+
else
170+
Symbol(namespace,:₊,x)
171+
end
170172
end
171173

172174
function namespace_variables(sys::AbstractSystem)
@@ -213,7 +215,7 @@ end
213215
independent_variable(sys::AbstractSystem) = sys.iv
214216
function states(sys::AbstractSystem)
215217
unique(isempty(sys.systems) ?
216-
setdiff(sys.states, value.(sys.pins)) :
218+
sys.states :
217219
[sys.states;reduce(vcat,namespace_variables.(sys.systems))])
218220
end
219221
parameters(sys::AbstractSystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,20 @@ end
6969

7070
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
7171
# optimization
72-
dvs′ = makesym.(value.(dvs), states=dvs)
73-
ps′ = makesym.(value.(ps), states=dvs)
72+
obsvars = map(eq->eq.lhs, observed(sys))
73+
fulldvs = [dvs; obsvars]
74+
fulldvs′ = makesym.(value.(fulldvs))
7475

75-
sub = Dict(dvs .=> dvs′)
76+
sub = Dict(fulldvs .=> fulldvs′)
7677
# substitute x(t) by just x
7778
rhss = [substitute(deq.rhs, sub) for deq equations(sys)]
78-
return build_function(rhss, dvs′, ps′, sys.iv;
79+
obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq observed(sys)]
80+
81+
dvs′ = fulldvs′[1:length(dvs)]
82+
ps′ = makesym.(value.(ps), states=())
83+
84+
# TODO: add an optional check on the ordering of observed equations
85+
return build_function(Let(obss, rhss), dvs′, ps′, sys.iv;
7986
conv = ODEToExpr(sys),kwargs...)
8087
end
8188

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ end
168168

169169
function collect_var!(states, parameters, var, iv)
170170
isequal(var, iv) && return nothing
171-
if isparameter(var) || isparameter(operation(var))
171+
if isparameter(var) || (istree(var) && isparameter(operation(var)))
172172
push!(parameters, var)
173173
else
174174
push!(states, var)

src/systems/reduction.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function flatten(sys::ODESystem)
88
independent_variable(sys),
99
states(sys),
1010
parameters(sys),
11+
pins=pins(sys),
1112
observed=observed(sys))
1213
end
1314
end
@@ -56,8 +57,6 @@ end
5657

5758
function alias_elimination(sys::ODESystem)
5859
eqs = vcat(equations(sys), observed(sys))
59-
neweqs = Equation[]; sizehint!(neweqs, length(eqs))
60-
subs = Pair[]
6160
diff_vars = filter(!isnothing, map(eqs) do eq
6261
if isdiffeq(eq)
6362
arguments(eq.lhs)[1]
@@ -67,6 +66,9 @@ function alias_elimination(sys::ODESystem)
6766
end) |> Set
6867

6968
deps = Set()
69+
subs = Pair[]
70+
neweqs = Equation[]; sizehint!(neweqs, length(eqs))
71+
7072
for (i, eq) in enumerate(eqs)
7173
# only substitute when the variable is algebraic
7274
if isdiffeq(eq)
@@ -107,15 +109,12 @@ function alias_elimination(sys::ODESystem)
107109
end
108110
end
109111

110-
eqs′ = substitute_aliases(neweqs, Dict(subs))
111-
112112
alias_vars = first.(subs)
113113
sys_states = states(sys)
114-
alias_eqs = alias_vars .~ last.(subs)
115-
#alias_eqs = topsort_equations(alias_eqs, sys_states)
114+
alias_eqs = topsort_equations(alias_vars .~ last.(subs), sys_states)
116115

117116
newstates = setdiff(sys_states, alias_vars)
118-
ODESystem(eqs′, sys.iv, newstates, parameters(sys), observed=alias_eqs)
117+
ODESystem(neweqs, sys.iv, newstates, parameters(sys), observed=alias_eqs)
119118
end
120119

121120
"""

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ Symbol("z⦗t⦘")
222222
"""
223223
function tosymbol(t::Term; states=nothing, escape=true)
224224
if operation(t) isa Sym
225-
if states !== nothing && !(any(isequal(t), states))
225+
if states !== nothing && !(t in states)
226226
return nameof(operation(t))
227227
end
228228
op = nameof(operation(t))

test/reduction.jl

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -29,99 +29,113 @@ D = Differential(t)
2929

3030
test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
3131

32-
eqs = [D(x) ~ σ*(y-x),
33-
D(y) ~ x*-z)-y + β,
34-
0 ~ sin(z) - x + y,
35-
sin(u) ~ x + y,
36-
2β ~ 2,
37-
x ~ a,
32+
eqs = [
33+
D(x) ~ σ*(y-x)
34+
D(y) ~ x*-z)-y + β
35+
0 ~ sin(z) - x + y
36+
sin(u) ~ x + y
37+
x ~ a
3838
]
3939

40-
lorenz1 = ODESystem(eqs,t,[u,x,y,z,a],[σ,ρ,β],name=:lorenz1)
40+
lorenz1 = ODESystem(eqs,t,name=:lorenz1)
4141

4242
lorenz1_aliased = alias_elimination(lorenz1)
4343
reduced_eqs = [
4444
D(x) ~ σ * (y - x),
45-
D(y) ~ x*-z)-y + 1,
45+
D(y) ~ x*-z)-y + β,
4646
0 ~ sin(z) - x + y,
4747
0 ~ x + y - sin(u),
4848
]
4949
test_equal.(equations(lorenz1_aliased), reduced_eqs)
50-
test_equal.(states(lorenz1_aliased), [u, x, y, z])
50+
@test isempty(setdiff(states(lorenz1_aliased), [u, x, y, z]))
5151
test_equal.(observed(lorenz1_aliased), [
52-
β ~ 1,
5352
a ~ x,
5453
])
5554

5655
# Multi-System Reduction
5756

57+
@variables s
5858
eqs1 = [
5959
D(x) ~ σ*(y-x) + F,
6060
D(y) ~ x*-z)-u,
6161
D(z) ~ x*y - β*z,
6262
u ~ x + y - z,
6363
]
6464

65-
lorenz1 = ODESystem(eqs1,pins=[F],name=:lorenz1)
66-
67-
eqs2 = [
68-
D(x) ~ F,
69-
D(y) ~ x*-z)-x,
70-
D(z) ~ x*y - β*z,
71-
u ~ x - y - z
72-
]
65+
lorenz = name -> ODESystem(eqs1,t,pins=[F],name=name)
66+
lorenz1 = lorenz(:lorenz1)
67+
lorenz2 = lorenz(:lorenz2)
7368

74-
lorenz2 = ODESystem(eqs2,pins=[F],name=:lorenz2)
75-
76-
connected = ODESystem([lorenz2.y ~ a + lorenz1.x,
77-
lorenz1.F ~ lorenz2.u,
78-
lorenz2.F ~ lorenz1.u],t,[a],[],systems=[lorenz1,lorenz2])
69+
connected = ODESystem([s ~ a + lorenz1.x
70+
lorenz2.y ~ s
71+
lorenz1.F ~ lorenz2.u
72+
lorenz2.F ~ lorenz1.u],t,systems=[lorenz1,lorenz2])
7973

8074
# Reduced Flattened System
8175

8276
flattened_system = ModelingToolkit.flatten(connected)
8377

8478
aliased_flattened_system = alias_elimination(flattened_system)
8579

86-
@test isequal(states(aliased_flattened_system), [
80+
@test setdiff(states(aliased_flattened_system), [
8781
a
8882
lorenz1.x
8983
lorenz1.y
9084
lorenz1.z
9185
lorenz2.x
9286
lorenz2.y
9387
lorenz2.z
94-
])
88+
]) |> isempty
9589

9690
@test setdiff(parameters(aliased_flattened_system), [
9791
lorenz1.σ
9892
lorenz1.ρ
9993
lorenz1.β
100-
lorenz1.F
101-
lorenz2.F
94+
lorenz2.σ
10295
lorenz2.ρ
10396
lorenz2.β
10497
]) |> isempty
10598

10699
reduced_eqs = [
107-
0 ~ a + lorenz1.x - lorenz2.y, # irreducible by alias elimination
108-
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
109-
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
110-
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
111-
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
112-
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
113-
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z
100+
0 ~ s - lorenz2.y
101+
D(lorenz1.x) ~ lorenz1.F + lorenz1.σ*(lorenz1.y + -1lorenz1.x)
102+
D(lorenz1.y) ~ -1lorenz1.u + lorenz1.x*(lorenz1.ρ + -1lorenz1.z)
103+
D(lorenz1.z) ~ lorenz1.x*lorenz1.y + -1lorenz1.β*lorenz1.z
104+
D(lorenz2.x) ~ lorenz2.F + lorenz2.σ*(lorenz2.y + -1lorenz2.x)
105+
D(lorenz2.y) ~ -1lorenz2.u + lorenz2.x*(lorenz2.ρ + -1lorenz2.z)
106+
D(lorenz2.z) ~ lorenz2.x*lorenz2.y + -1lorenz2.β*lorenz2.z
114107
]
115108
test_equal.(equations(aliased_flattened_system), reduced_eqs)
116109

117110
observed_eqs = [
118-
lorenz1.F ~ lorenz2.u,
119-
lorenz2.F ~ lorenz1.u,
120-
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z,
121-
lorenz2.u ~ lorenz2.x - lorenz2.y - lorenz2.z,
111+
s ~ a + lorenz1.x
112+
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z
113+
lorenz2.u ~ lorenz2.x + lorenz2.y - lorenz2.z
114+
lorenz2.F ~ lorenz1.u
115+
lorenz1.F ~ lorenz2.u
122116
]
123117
test_equal.(observed(aliased_flattened_system), observed_eqs)
124118

119+
pp = [
120+
lorenz1.σ => 10
121+
lorenz1.ρ => 28
122+
lorenz1.β => 8/3
123+
lorenz2.σ => 10
124+
lorenz2.ρ => 28
125+
lorenz2.β => 8/3
126+
]
127+
u0 = [
128+
a => 1.0
129+
lorenz1.x => 1.0
130+
lorenz1.y => 0.0
131+
lorenz1.z => 0.0
132+
lorenz2.x => 1.0
133+
lorenz2.y => 0.0
134+
lorenz2.z => 0.0
135+
]
136+
prob1 = ODEProblem(aliased_flattened_system, u0, (0.0, 100.0), pp)
137+
solve(prob1, Rodas5())
138+
125139
# issue #578
126140

127141
let
@@ -131,10 +145,10 @@ let
131145
D(x) ~ x + y
132146
x ~ y
133147
]
134-
sys = ODESystem(eqs, t, [x], [])
148+
sys = ODESystem(eqs, t)
135149
asys = alias_elimination(flatten(sys))
136150

137-
test_equal.(asys.eqs, [D(x) ~ 2x])
151+
test_equal.(asys.eqs, [D(x) ~ x + y])
138152
test_equal.(asys.observed, [y ~ x])
139153
end
140154

@@ -149,17 +163,17 @@ let
149163
@parameters k_P
150164
pc = ODESystem(Equation[u_c ~ k_P * y_c], t, pins=[y_c], name=:pc)
151165
connections = [
152-
ol.u ~ pc.u_c,
166+
ol.u ~ pc.u_c
153167
pc.y_c ~ ol.y
154168
]
155169
connected = ODESystem(connections, t, systems=[ol, pc])
156170
@test equations(connected) isa Vector{Equation}
157171
sys = flatten(connected)
158172
reduced_sys = alias_elimination(sys)
159173
ref_eqs = [
160-
D(ol.x) ~ ol.a*ol.x + ol.b*pc.u_c
161-
0 ~ ol.c*ol.x + ol.d*pc.u_c - ol.y
162-
0 ~ pc.k_P*ol.y - pc.u_c
174+
D(ol.x) ~ ol.a*ol.x + ol.b*ol.u
175+
0 ~ ol.c*ol.x + ol.d*ol.u + -1ol.y
176+
0 ~ pc.k_P*pc.y_c + -1pc.u_c
163177
]
164178
@test ref_eqs == equations(reduced_sys)
165179
end

0 commit comments

Comments
 (0)