Skip to content

Commit 0ee8a7c

Browse files
committed
More precise solver state detection in ODAEProblem
1 parent bcd6736 commit 0ee8a7c

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

src/structural_transformation/codegen.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ function build_torn_function(
288288
append!(mass_matrix_diag, zeros(length(torn_eqs_idxs)))
289289
end
290290
end
291+
sort!(states_idxs)
291292

292293
mass_matrix = needs_extending ? Diagonal(mass_matrix_diag) : I
293294

@@ -323,11 +324,18 @@ function build_torn_function(
323324
if expression
324325
expr, states
325326
else
326-
observedfun = let state = state, dict=Dict(), assignments=assignments, deps=(deps, invdeps), sol_states=sol_states, var2assignment=var2assignment
327+
observedfun = let state=state,
328+
dict=Dict(),
329+
is_solver_state_idxs=insorted.(1:length(fullvars), (states_idxs,)),
330+
assignments=assignments,
331+
deps=(deps, invdeps),
332+
sol_states=sol_states,
333+
var2assignment=var2assignment
334+
327335
function generated_observed(obsvar, u, p, t)
328336
obs = get!(dict, value(obsvar)) do
329337
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
330-
assignments, deps, sol_states, var2assignment,
338+
is_solver_state_idxs, assignments, deps, sol_states, var2assignment,
331339
checkbounds=checkbounds,
332340
)
333341
end
@@ -364,6 +372,7 @@ end
364372

365373
function build_observed_function(
366374
state, ts, var_eq_matching, var_sccs,
375+
is_solver_state_idxs,
367376
assignments,
368377
deps,
369378
sol_states,
@@ -388,8 +397,8 @@ function build_observed_function(
388397
fullvars = state.fullvars
389398
s = state.structure
390399
graph = s.graph
391-
diffvars = map(i->fullvars[i], diffvars_range(s))
392-
algvars = map(i->fullvars[i], algvars_range(s))
400+
solver_states = fullvars[is_solver_state_idxs]
401+
algvars = fullvars[.!is_solver_state_idxs]
393402

394403
required_algvars = Set(intersect(algvars, vars))
395404
obs = observed(sys)
@@ -471,7 +480,7 @@ function build_observed_function(
471480

472481
ex = Code.toexpr(Func(
473482
[
474-
DestructuredArgs(diffvars, inbounds=!checkbounds)
483+
DestructuredArgs(solver_states, inbounds=!checkbounds)
475484
DestructuredArgs(parameters(sys), inbounds=!checkbounds)
476485
independent_variables(sys)
477486
],

0 commit comments

Comments
 (0)