Skip to content

Commit 8e95263

Browse files
authored
Merge pull request #1499 from SciML/myb/fixobs
More robust observed equations building
2 parents cfc1b4d + 0ee8a7c commit 8e95263

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

src/structural_transformation/codegen.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ function build_torn_function(
288288
append!(mass_matrix_diag, zeros(length(torn_eqs_idxs)))
289289
end
290290
end
291+
sort!(states_idxs)
291292

292293
mass_matrix = needs_extending ? Diagonal(mass_matrix_diag) : I
293294

@@ -323,11 +324,18 @@ function build_torn_function(
323324
if expression
324325
expr, states
325326
else
326-
observedfun = let state = state, dict=Dict(), assignments=assignments, deps=(deps, invdeps), sol_states=sol_states, var2assignment=var2assignment
327+
observedfun = let state=state,
328+
dict=Dict(),
329+
is_solver_state_idxs=insorted.(1:length(fullvars), (states_idxs,)),
330+
assignments=assignments,
331+
deps=(deps, invdeps),
332+
sol_states=sol_states,
333+
var2assignment=var2assignment
334+
327335
function generated_observed(obsvar, u, p, t)
328336
obs = get!(dict, value(obsvar)) do
329337
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
330-
assignments, deps, sol_states, var2assignment,
338+
is_solver_state_idxs, assignments, deps, sol_states, var2assignment,
331339
checkbounds=checkbounds,
332340
)
333341
end
@@ -364,6 +372,7 @@ end
364372

365373
function build_observed_function(
366374
state, ts, var_eq_matching, var_sccs,
375+
is_solver_state_idxs,
367376
assignments,
368377
deps,
369378
sol_states,
@@ -388,8 +397,8 @@ function build_observed_function(
388397
fullvars = state.fullvars
389398
s = state.structure
390399
graph = s.graph
391-
diffvars = map(i->fullvars[i], diffvars_range(s))
392-
algvars = map(i->fullvars[i], algvars_range(s))
400+
solver_states = fullvars[is_solver_state_idxs]
401+
algvars = fullvars[.!is_solver_state_idxs]
393402

394403
required_algvars = Set(intersect(algvars, vars))
395404
obs = observed(sys)
@@ -433,6 +442,11 @@ function build_observed_function(
433442
union!(required_algvars, intersect(algvars, vs))
434443
empty!(vs)
435444
end
445+
for eq in assignments
446+
vars!(vs, eq.rhs)
447+
union!(required_algvars, intersect(algvars, vs))
448+
empty!(vs)
449+
end
436450

437451
varidxs = findall(x->x in required_algvars, fullvars)
438452
subset = find_solve_sequence(var_sccs, varidxs)
@@ -466,15 +480,15 @@ function build_observed_function(
466480

467481
ex = Code.toexpr(Func(
468482
[
469-
DestructuredArgs(diffvars, inbounds=!checkbounds)
483+
DestructuredArgs(solver_states, inbounds=!checkbounds)
470484
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
471485
independent_variables(sys)
472486
],
473487
[],
474488
pre(Let(
475489
[
476-
assignments[is_not_prepended_assignment]
477490
collect(Iterators.flatten(solves))
491+
assignments[is_not_prepended_assignment]
478492
map(eq -> eq.lhseq.rhs, obs[1:maxidx])
479493
subs
480494
],

test/components.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,26 @@ prob = ODAEProblem(sys, u0, (0, 10.0))
4545
sol = solve(prob, Rodas4())
4646
check_rc_sol(sol)
4747

48+
let
49+
# 1478
50+
@named resistor2 = Resistor(R=R)
51+
rc_eqs2 = [
52+
connect(source.p, resistor.p)
53+
connect(resistor.n, resistor2.p)
54+
connect(resistor2.n, capacitor.p)
55+
connect(capacitor.n, source.n)
56+
connect(capacitor.n, ground.g)
57+
]
58+
59+
@named _rc_model2 = ODESystem(rc_eqs2, t)
60+
@named rc_model2 = compose(_rc_model2,
61+
[resistor, resistor2, capacitor, source, ground])
62+
sys2 = structural_simplify(rc_model2)
63+
prob2 = ODAEProblem(sys2, u0, (0, 10.0))
64+
sol2 = solve(prob2, Tsit5())
65+
@test sol2[source.p.i] == sol2[rc_model2.source.p.i] == -sol2[capacitor.i]
66+
end
67+
4868
# Outer/inner connections
4969
function rc_component(;name)
5070
R = 1

0 commit comments

Comments
 (0)