Skip to content

Commit 336e5cb

Browse files
authored
Merge pull request #837 from SciML/myb/obsfix
Handle dependencies of observed variables properly
2 parents d3ecbf1 + cecd7f7 commit 336e5cb

File tree

11 files changed

+85
-44
lines changed

11 files changed

+85
-44
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "5.9.0"
4+
version = "5.9.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

docs/src/tutorials/acausal_components.md

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ equalities before solving. Let's see this in action.
1212
## Copy-Paste Example
1313

1414
```julia
15-
using ModelingToolkit, Plots
15+
using ModelingToolkit, Plots, DifferentialEquations
1616

1717
@parameters t
1818

@@ -23,15 +23,15 @@ function Pin(;name)
2323
end
2424

2525
function Ground(;name)
26-
g = Pin(:g)
26+
@named g = Pin()
2727
eqs = [g.v ~ 0]
2828
ODESystem(eqs, t, [], [], systems=[g], name=name)
2929
end
3030

3131
function Resistor(;name, R = 1.0)
3232
val = R
33-
p = Pin(:p)
34-
n = Pin(:n)
33+
@named p = Pin()
34+
@named n = Pin()
3535
@variables v(t)
3636
@parameters R
3737
eqs = [
@@ -44,8 +44,8 @@ end
4444

4545
function Capacitor(; name, C = 1.0)
4646
val = C
47-
p = Pin(:p)
48-
n = Pin(:n)
47+
@named p = Pin()
48+
@named n = Pin()
4949
@variables v(t)
5050
@parameters C
5151
D = Differential(t)
@@ -59,8 +59,8 @@ end
5959

6060
function ConstantVoltage(;name, V = 1.0)
6161
val = V
62-
p = Pin(:p)
63-
n = Pin(:n)
62+
@named p = Pin()
63+
@named n = Pin()
6464
@parameters V
6565
eqs = [
6666
V ~ p.v - n.v
@@ -135,7 +135,7 @@ names to correspond to duplicates of this topology with unique variables.
135135
One can then construct a `Pin` like:
136136

137137
```julia
138-
Pin(:mypin1)
138+
Pin(name=:mypin1)
139139
```
140140

141141
or equivalently using the `@named` helper macro:
@@ -151,7 +151,7 @@ that the voltage in such a `Pin` is equal to zero. This gives:
151151

152152
```julia
153153
function Ground(;name)
154-
g = Pin(:g)
154+
@named g = Pin()
155155
eqs = [g.v ~ 0]
156156
ODESystem(eqs, t, [], [], systems=[g], name=name)
157157
end
@@ -167,8 +167,8 @@ zero. This leads to our resistor equations:
167167
```julia
168168
function Resistor(;name, R = 1.0)
169169
val = R
170-
p = Pin(:p)
171-
n = Pin(:n)
170+
@named p = Pin()
171+
@named n = Pin()
172172
@variables v(t)
173173
@parameters R
174174
eqs = [
@@ -190,8 +190,8 @@ Using our knowledge of circuits we similarly construct the Capacitor:
190190
```julia
191191
function Capacitor(; name, C = 1.0)
192192
val = C
193-
p = Pin(:p)
194-
n = Pin(:n)
193+
@named p = Pin()
194+
@named n = Pin()
195195
@variables v(t)
196196
@parameters C
197197
D = Differential(t)
@@ -212,8 +212,8 @@ model this as:
212212
```julia
213213
function ConstantVoltage(;name, V = 1.0)
214214
val = V
215-
p = Pin(:p)
216-
n = Pin(:n)
215+
@named p = Pin()
216+
@named n = Pin()
217217
@parameters V
218218
eqs = [
219219
V ~ p.v - n.v
@@ -406,10 +406,15 @@ observed(sys)
406406
capacitor₊n₊v(t) ~ 0.0
407407
source₊n₊v(t) ~ 0.0
408408
ground₊g₊i(t) ~ 0.0
409-
409+
source₊n₊i(t) ~ capacitor₊p₊i(t)
410+
source₊p₊i(t) ~ -capacitor₊p₊i(t)
411+
capacitor₊n₊i(t) ~ -capacitor₊p₊i(t)
412+
resistor₊n₊i(t) ~ -capacitor₊p₊i(t)
410413
ground₊g₊v(t) ~ 0.0
411414
source₊p₊v(t) ~ source₊V
412415
capacitor₊p₊v(t) ~ capacitor₊v(t)
416+
resistor₊p₊v(t) ~ source₊p₊v(t)
417+
resistor₊n₊v(t) ~ capacitor₊p₊v(t)
413418
resistor₊v(t) ~ -((capacitor₊p₊v(t)) - (source₊p₊v(t)))
414419
```
415420

@@ -428,5 +433,5 @@ sol[resistor.v]
428433
or we can plot the timeseries of the resistor's voltage:
429434

430435
```julia
431-
plot(sol,vars=(resistor.v,))
436+
plot(sol, vars=[resistor.v])
432437
```

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using ModelingToolkit: ODESystem, var_from_nested_derivative, Differential,
1616
states, equations, vars, Symbolic, diff2term, value,
1717
operation, arguments, Sym, Term, simplify, solve_for,
1818
isdiffeq, isdifferential,
19-
get_structure, default_u0, default_p
19+
get_structure, get_reduced_states, default_u0, default_p
2020

2121
using ModelingToolkit.BipartiteGraphs
2222
using ModelingToolkit.SystemStructures

src/structural_transformation/codegen.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,14 @@ function build_observed_function(
272272
required_algvars = Set(intersect(algvars, syms_set))
273273
obs = observed(sys)
274274
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
275-
for sym in syms
276-
idx = get(observed_idx, sym, nothing)
275+
# FIXME: this is a rather rough estimate of dependencies.
276+
maxidx = 0
277+
for (i, s) in enumerate(syms)
278+
idx = get(observed_idx, s, nothing)
277279
idx === nothing && continue
280+
idx > maxidx && (maxidx = idx)
281+
end
282+
for idx in 1:maxidx
278283
vs = vars(obs[idx].rhs)
279284
union!(required_algvars, intersect(algvars, vs))
280285
end
@@ -308,7 +313,10 @@ function build_observed_function(
308313
],
309314
[],
310315
Let(
311-
collect(Iterators.flatten(solves)),
316+
[
317+
collect(Iterators.flatten(solves))
318+
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
319+
],
312320
isscalar ? output[1] : MakeArray(output, output_type)
313321
)
314322
) |> Code.toexpr

src/structural_transformation/tearing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ function tearing_reassemble(sys; simplify=false)
180180
@set! sys.structure.algeqs = newalgeqs
181181
@set! sys.eqs = neweqs
182182
@set! sys.states = newstates
183+
@set! sys.reduced_states = [get_reduced_states(sys); solvars]
183184
@set! sys.observed = vcat(observed(sys), obseqs)
184185
return sys
185186
end

src/systems/abstractsystem.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ for prop in [
154154
:inequality_constraints
155155
:controls
156156
:loss
157+
:reduced_states
157158
]
158159
fname1 = Symbol(:get_, prop)
159160
fname2 = Symbol(:has_, prop)
@@ -467,6 +468,13 @@ end
467468
"""
468469
$(SIGNATURES)
469470
470-
Structurally simplify algebraic equations in a system.
471+
Structurally simplify algebraic equations in a system and compute the
472+
topological sort of the observed equations.
471473
"""
472-
structural_simplify(sys::AbstractSystem) = tearing(alias_elimination(sys))
474+
function structural_simplify(sys::AbstractSystem)
475+
sys = tearing(alias_elimination(sys))
476+
s = structure(sys)
477+
fullstates = [get_reduced_states(sys); states(sys)]
478+
@set! sys.observed = topsort_equations(observed(sys), fullstates)
479+
return sys
480+
end

src/systems/alias_elimination.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ using SymbolicUtils: Rewriters
33
const KEEP = typemin(Int)
44

55
function alias_elimination(sys)
6-
sys = flatten(sys)
7-
s = get_structure(sys)
8-
if !(s isa SystemStructure)
9-
sys = initialize_system_structure(sys)
10-
s = structure(sys)
11-
end
6+
# FIXME: update `structure` too
7+
#sys = flatten(sys)
8+
#s = get_structure(sys)
9+
#if !(s isa SystemStructure)
10+
sys = initialize_system_structure(sys)
11+
s = structure(sys)
12+
#end
1213
is_linear_equations, eadj, cadj = find_linear_equations(sys)
1314

1415
v_eliminated, v_types, n_null_vars, degenerate_equations, linear_equations = alias_eliminate_graph(
@@ -18,9 +19,12 @@ function alias_elimination(sys)
1819
s = structure(sys)
1920
@unpack fullvars, graph = s
2021

22+
n_reduced_states = length(v_eliminated) - n_null_vars
23+
reduced_states = similar(v_eliminated, Any, n_reduced_states)
2124
subs = OrderedDict()
22-
if length(v_eliminated) - n_null_vars > 0
23-
for v in v_eliminated[n_null_vars+1:end]
25+
if n_reduced_states > 0
26+
for (i, v) in enumerate(@view v_eliminated[n_null_vars+1:end])
27+
reduced_states[i] = fullvars[v]
2428
subs[fullvars[v]] = iszeroterm(v_types, v) ? 0.0 :
2529
isalias(v_types, v) ? fullvars[alias(v_types, v)] :
2630
-fullvars[negalias(v_types, v)]
@@ -63,6 +67,7 @@ function alias_elimination(sys)
6367

6468
@set! sys.eqs = eqs
6569
@set! sys.states = newstates
70+
@set! sys.reduced_states = [get_reduced_states(sys); reduced_states]
6671
@set! sys.observed = [get_observed(sys); [lhs ~ rhs for (lhs, rhs) in pairs(subs)]]
6772
@set! sys.structure = nothing
6873
return sys

src/systems/diffeqs/odesystem.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct ODESystem <: AbstractODESystem
7474
structure: structural information of the system
7575
"""
7676
structure::Any
77+
reduced_states::Vector
7778
end
7879

7980
function ODESystem(
@@ -101,7 +102,7 @@ function ODESystem(
101102
if length(unique(sysnames)) != length(sysnames)
102103
throw(ArgumentError("System names must be unique."))
103104
end
104-
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p, nothing)
105+
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p, nothing, [])
105106
end
106107

107108
var_from_nested_derivative(x, i=0) = (missing, missing)
@@ -235,7 +236,7 @@ ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...
235236
$(SIGNATURES)
236237
237238
Build the observed function assuming the observed equations are all explicit,
238-
i.e. there are no cycles or dependencies.
239+
i.e. there are no cycles.
239240
"""
240241
function build_explicit_observed_function(
241242
sys, syms;
@@ -249,7 +250,14 @@ function build_explicit_observed_function(
249250

250251
obs = observed(sys)
251252
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
252-
output = map(sym->obs[observed_idx[sym]].rhs, syms)
253+
output = similar(syms, Any)
254+
# FIXME: this is a rather rough estimate of dependencies.
255+
maxidx = 0
256+
for (i, s) in enumerate(syms)
257+
idx = observed_idx[s]
258+
idx > maxidx && (maxidx = idx)
259+
output[i] = obs[idx].rhs
260+
end
253261

254262
ex = Func(
255263
[
@@ -258,7 +266,10 @@ function build_explicit_observed_function(
258266
independent_variable(sys)
259267
],
260268
[],
261-
isscalar ? output[1] : MakeArray(output, output_type)
269+
Let(
270+
map(eq -> eq.lhseq.rhs, obs[1:maxidx]),
271+
isscalar ? output[1] : MakeArray(output, output_type)
272+
)
262273
) |> toexpr
263274

264275
expression ? ex : @RuntimeGeneratedFunction(ex)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct NonlinearSystem <: AbstractSystem
4848
structure: structural information of the system
4949
"""
5050
structure::Any
51+
reduced_states::Any
5152
end
5253

5354
function NonlinearSystem(eqs, states, ps;
@@ -60,7 +61,7 @@ function NonlinearSystem(eqs, states, ps;
6061
default_p isa Dict || (default_p = Dict(default_p))
6162
default_u0 = Dict(value(k) => value(default_u0[k]) for k in keys(default_u0))
6263
default_p = Dict(value(k) => value(default_p[k]) for k in keys(default_p))
63-
NonlinearSystem(eqs, value.(states), value.(ps), observed, name, systems, default_u0, default_p, nothing)
64+
NonlinearSystem(eqs, value.(states), value.(ps), observed, name, systems, default_u0, default_p, nothing, [])
6465
end
6566

6667
function calculate_jacobian(sys::NonlinearSystem;sparse=false,simplify=false)

test/components.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,16 @@ sol = solve(prob, Rodas4())
100100
@test sol[capacitor.n.i] == -sol[capacitor.p.i]
101101
@test iszero(sol[ground.g.i])
102102
@test iszero(sol[ground.g.v])
103+
@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v]
103104

104105
prob = ODAEProblem(sys, u0, (0, 10.0))
105-
sol = solve(prob, Rodas4())
106+
sol = solve(prob, Tsit5())
106107

107108
@test sol[resistor.p.i] == sol[capacitor.p.i]
108109
@test sol[resistor.n.i] == -sol[capacitor.p.i]
109110
@test sol[capacitor.n.i] == -sol[capacitor.p.i]
110111
@test iszero(sol[ground.g.i])
111112
@test iszero(sol[ground.g.v])
113+
@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v]
112114
#using Plots
113115
#plot(sol)

0 commit comments

Comments
 (0)