Skip to content

Commit a31fc03

Browse files
authored
Merge pull request #1489 from SciML/myb/obs
Make observed function building resilient wrt namespacing
2 parents dd6bbd3 + fb37e35 commit a31fc03

File tree

4 files changed

+62
-18
lines changed

4 files changed

+62
-18
lines changed

src/structural_transformation/codegen.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -393,20 +393,40 @@ function build_observed_function(
393393

394394
required_algvars = Set(intersect(algvars, vars))
395395
obs = observed(sys)
396-
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
397-
# FIXME: this is a rather rough estimate of dependencies.
398-
maxidx = 0
396+
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
397+
namespaced_to_obs = Dict(states(sys, x.lhs) => x.lhs for x in obs)
398+
namespaced_to_sts = Dict(states(sys, x) => x for x in states(sys))
399399
sts = Set(states(sys))
400+
401+
# FIXME: This is a rather rough estimate of dependencies. We assume
402+
# the expression depends on everything before the `maxidx`.
403+
subs = Dict()
404+
maxidx = 0
400405
for (i, s) in enumerate(dep_vars)
401406
idx = get(observed_idx, s, nothing)
402-
if idx === nothing
403-
if !(s in sts)
407+
if idx !== nothing
408+
idx > maxidx && (maxidx = idx)
409+
else
410+
s′ = get(namespaced_to_obs, s, nothing)
411+
if s′ !== nothing
412+
subs[s] = s′
413+
s = s′
414+
idx = get(observed_idx, s, nothing)
415+
end
416+
if idx !== nothing
417+
idx > maxidx && (maxidx = idx)
418+
elseif !(s in sts)
419+
s′ = get(namespaced_to_sts, s, nothing)
420+
if s′ !== nothing
421+
subs[s] = s′
422+
continue
423+
end
404424
throw(ArgumentError("$s is either an observed nor a state variable."))
405425
end
406426
continue
407427
end
408-
idx > maxidx && (maxidx = idx)
409428
end
429+
ts = map(t->substitute(t, subs), ts)
410430
vs = Set()
411431
for idx in 1:maxidx
412432
vars!(vs, obs[idx].rhs)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
357357
end
358358

359359
M = calculate_massmatrix(sys)
360-
360+
361361
_M = if sparse && !(u0 === nothing || M === I)
362362
SparseArrays.sparse(M)
363363
elseif u0 === nothing || M === I

src/systems/diffeqs/odesystem.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,21 +279,39 @@ function build_explicit_observed_function(
279279

280280
obs = observed(sys)
281281
sts = Set(states(sys))
282-
observed_idx = Dict(map(x->x.lhs, obs) .=> 1:length(obs))
282+
observed_idx = Dict(x.lhs => i for (i, x) in enumerate(obs))
283+
namespaced_to_obs = Dict(states(sys, x.lhs) => x.lhs for x in obs)
284+
namespaced_to_sts = Dict(states(sys, x) => x for x in states(sys))
283285

284286
# FIXME: This is a rather rough estimate of dependencies. We assume
285287
# the expression depends on everything before the `maxidx`.
288+
subs = Dict()
286289
maxidx = 0
287290
for (i, s) in enumerate(dep_vars)
288291
idx = get(observed_idx, s, nothing)
289-
if idx === nothing
290-
if !(s in sts)
292+
if idx !== nothing
293+
idx > maxidx && (maxidx = idx)
294+
else
295+
s′ = get(namespaced_to_obs, s, nothing)
296+
if s′ !== nothing
297+
subs[s] = s′
298+
s = s′
299+
idx = get(observed_idx, s, nothing)
300+
end
301+
if idx !== nothing
302+
idx > maxidx && (maxidx = idx)
303+
elseif !(s in sts)
304+
s′ = get(namespaced_to_sts, s, nothing)
305+
if s′ !== nothing
306+
subs[s] = s′
307+
continue
308+
end
291309
throw(ArgumentError("$s is either an observed nor a state variable."))
292310
end
293311
continue
294312
end
295-
idx > maxidx && (maxidx = idx)
296313
end
314+
ts = map(t->substitute(t, subs), ts)
297315
obsexprs = map(eq -> eq.lhseq.rhs, obs[1:maxidx])
298316

299317
dvs = DestructuredArgs(states(sys), inbounds=!checkbounds)

test/components.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ function check_contract(sys)
1919
end
2020
end
2121

22+
function check_rc_sol(sol)
23+
@test sol[rc_model.resistor.p.i] == sol[resistor.p.i] == sol[capacitor.p.i]
24+
@test sol[rc_model.resistor.n.i] == sol[resistor.n.i] == -sol[capacitor.p.i]
25+
@test sol[rc_model.capacitor.n.i] ==sol[capacitor.n.i] == -sol[capacitor.p.i]
26+
@test iszero(sol[rc_model.ground.g.i])
27+
@test iszero(sol[rc_model.ground.g.v])
28+
@test sol[rc_model.resistor.v] == sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v]
29+
end
30+
2231
include("../examples/rc_model.jl")
2332

2433
sys = structural_simplify(rc_model)
@@ -31,13 +40,10 @@ u0 = [
3140
]
3241
prob = ODEProblem(sys, u0, (0, 10.0))
3342
sol = solve(prob, Rodas4())
34-
35-
@test sol[resistor.p.i] == sol[capacitor.p.i]
36-
@test sol[resistor.n.i] == -sol[capacitor.p.i]
37-
@test sol[capacitor.n.i] == -sol[capacitor.p.i]
38-
@test iszero(sol[ground.g.i])
39-
@test iszero(sol[ground.g.v])
40-
@test sol[resistor.v] == sol[source.p.v] - sol[capacitor.p.v]
43+
check_rc_sol(sol)
44+
prob = ODAEProblem(sys, u0, (0, 10.0))
45+
sol = solve(prob, Rodas4())
46+
check_rc_sol(sol)
4147

4248
# Outer/inner connections
4349
function rc_component(;name)

0 commit comments

Comments
 (0)