@@ -288,6 +288,7 @@ function build_torn_function(
288
288
append! (mass_matrix_diag, zeros (length (torn_eqs_idxs)))
289
289
end
290
290
end
291
+ sort! (states_idxs)
291
292
292
293
mass_matrix = needs_extending ? Diagonal (mass_matrix_diag) : I
293
294
@@ -323,11 +324,18 @@ function build_torn_function(
323
324
if expression
324
325
expr, states
325
326
else
326
- observedfun = let state = state, dict= Dict (), assignments= assignments, deps= (deps, invdeps), sol_states= sol_states, var2assignment= var2assignment
327
+ observedfun = let state= state,
328
+ dict= Dict (),
329
+ is_solver_state_idxs= insorted .(1 : length (fullvars), (states_idxs,)),
330
+ assignments= assignments,
331
+ deps= (deps, invdeps),
332
+ sol_states= sol_states,
333
+ var2assignment= var2assignment
334
+
327
335
function generated_observed (obsvar, u, p, t)
328
336
obs = get! (dict, value (obsvar)) do
329
337
build_observed_function (state, obsvar, var_eq_matching, var_sccs,
330
- assignments, deps, sol_states, var2assignment,
338
+ is_solver_state_idxs, assignments, deps, sol_states, var2assignment,
331
339
checkbounds= checkbounds,
332
340
)
333
341
end
364
372
365
373
function build_observed_function (
366
374
state, ts, var_eq_matching, var_sccs,
375
+ is_solver_state_idxs,
367
376
assignments,
368
377
deps,
369
378
sol_states,
@@ -388,8 +397,8 @@ function build_observed_function(
388
397
fullvars = state. fullvars
389
398
s = state. structure
390
399
graph = s. graph
391
- diffvars = map (i -> fullvars[i], diffvars_range (s))
392
- algvars = map (i -> fullvars[i], algvars_range (s))
400
+ solver_states = fullvars[is_solver_state_idxs]
401
+ algvars = fullvars[. ! is_solver_state_idxs]
393
402
394
403
required_algvars = Set (intersect (algvars, vars))
395
404
obs = observed (sys)
@@ -433,6 +442,11 @@ function build_observed_function(
433
442
union! (required_algvars, intersect (algvars, vs))
434
443
empty! (vs)
435
444
end
445
+ for eq in assignments
446
+ vars! (vs, eq. rhs)
447
+ union! (required_algvars, intersect (algvars, vs))
448
+ empty! (vs)
449
+ end
436
450
437
451
varidxs = findall (x-> x in required_algvars, fullvars)
438
452
subset = find_solve_sequence (var_sccs, varidxs)
@@ -466,15 +480,15 @@ function build_observed_function(
466
480
467
481
ex = Code. toexpr (Func (
468
482
[
469
- DestructuredArgs (diffvars , inbounds= ! checkbounds)
483
+ DestructuredArgs (solver_states , inbounds= ! checkbounds)
470
484
DestructuredArgs (parameters (sys), inbounds= ! checkbounds)
471
485
independent_variables (sys)
472
486
],
473
487
[],
474
488
pre (Let (
475
489
[
476
- assignments[is_not_prepended_assignment]
477
490
collect (Iterators. flatten (solves))
491
+ assignments[is_not_prepended_assignment]
478
492
map (eq -> eq. lhs← eq. rhs, obs[1 : maxidx])
479
493
subs
480
494
],
0 commit comments