@@ -4,9 +4,9 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
4
4
5
5
const MAX_INLINE_NLSOLVE_SIZE = 8
6
6
7
- function torn_system_with_nlsolve_jacobian_sparsity (sys , var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
8
- s = structure (sys)
9
- @unpack fullvars, graph = s
7
+ function torn_system_with_nlsolve_jacobian_sparsity (state , var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
8
+ fullvars = state . fullvars
9
+ graph = state . structure . graph
10
10
11
11
# The sparsity pattern of `nlsolve(f, u, p)` w.r.t `p` is difficult to
12
12
# determine in general. Consider the "simplest" case, a linear system. We
@@ -73,6 +73,7 @@ function torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sc
73
73
nlsolve_vars_set = BitSet (nlsolve_vars)
74
74
75
75
I = Int[]; J = Int[]
76
+ s = state. structure
76
77
for ieq in 𝑠vertices (graph)
77
78
nieq = get (eqs2idx, ieq, 0 )
78
79
nieq == 0 && continue
@@ -244,15 +245,15 @@ function build_torn_function(
244
245
push! (rhss, eq. rhs)
245
246
end
246
247
247
- s = structure (sys)
248
- @unpack fullvars = s
249
- var_eq_matching, var_sccs = algebraic_variables_scc (sys )
248
+ state = get_or_construct_tearing_state (sys)
249
+ fullvars = state . fullvars
250
+ var_eq_matching, var_sccs = algebraic_variables_scc (state )
250
251
condensed_graph = MatchedCondensationGraph (
251
- DiCMOBiGraph {true} (complete (s . graph), complete (var_eq_matching)), var_sccs)
252
+ DiCMOBiGraph {true} (complete (state . structure . graph), complete (var_eq_matching)), var_sccs)
252
253
toporder = topological_sort_by_dfs (condensed_graph)
253
254
var_sccs = var_sccs[toporder]
254
255
255
- states_idxs = collect (diffvars_range (s ))
256
+ states_idxs = collect (diffvars_range (state . structure ))
256
257
mass_matrix_diag = ones (length (states_idxs))
257
258
258
259
assignments, deps, sol_states = tearing_assignments (sys)
@@ -276,7 +277,7 @@ function build_torn_function(
276
277
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
277
278
isempty (torn_eqs_idxs) && continue
278
279
if length (torn_eqs_idxs) <= max_inlining_size
279
- nlsolve_expr = gen_nlsolve! (is_not_prepended_assignment, eqs[torn_eqs_idxs], s . fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds= checkbounds)
280
+ nlsolve_expr = gen_nlsolve! (is_not_prepended_assignment, eqs[torn_eqs_idxs], fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds= checkbounds)
280
281
append! (torn_expr, nlsolve_expr)
281
282
push! (nlsolve_scc_idxs, i)
282
283
else
@@ -297,7 +298,7 @@ function build_torn_function(
297
298
rhss
298
299
)
299
300
300
- states = s . fullvars[states_idxs]
301
+ states = fullvars[states_idxs]
301
302
syms = map (Symbol, states_idxs)
302
303
303
304
pre = get_postprocess_fbody (sys)
@@ -322,10 +323,10 @@ function build_torn_function(
322
323
if expression
323
324
expr, states
324
325
else
325
- observedfun = let sys = sys , dict= Dict (), assignments= assignments, deps= (deps, invdeps), sol_states= sol_states, var2assignment= var2assignment
326
+ observedfun = let state = state , dict= Dict (), assignments= assignments, deps= (deps, invdeps), sol_states= sol_states, var2assignment= var2assignment
326
327
function generated_observed (obsvar, u, p, t)
327
328
obs = get! (dict, value (obsvar)) do
328
- build_observed_function (sys , obsvar, var_eq_matching, var_sccs,
329
+ build_observed_function (state , obsvar, var_eq_matching, var_sccs,
329
330
assignments, deps, sol_states, var2assignment,
330
331
checkbounds= checkbounds,
331
332
)
@@ -336,7 +337,7 @@ function build_torn_function(
336
337
337
338
ODEFunction {true} (
338
339
@RuntimeGeneratedFunction (expr),
339
- sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity (sys , var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing ,
340
+ sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity (state , var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing ,
340
341
syms = syms,
341
342
observed = observedfun,
342
343
mass_matrix = mass_matrix,
@@ -362,7 +363,7 @@ function find_solve_sequence(sccs, vars)
362
363
end
363
364
364
365
function build_observed_function (
365
- sys , ts, var_eq_matching, var_sccs,
366
+ state , ts, var_eq_matching, var_sccs,
366
367
assignments,
367
368
deps,
368
369
sol_states,
@@ -379,12 +380,14 @@ function build_observed_function(
379
380
ts = Symbolics. scalarize .(value .(ts))
380
381
381
382
vars = Set ()
383
+ sys = state. sys
382
384
foreach (Base. Fix1 (vars!, vars), ts)
383
385
ivs = independent_variables (sys)
384
386
dep_vars = collect (setdiff (vars, ivs))
385
387
386
- s = structure (sys)
387
- @unpack fullvars, graph = s
388
+ fullvars = state. fullvars
389
+ s = state. structure
390
+ graph = s. graph
388
391
diffvars = map (i-> fullvars[i], diffvars_range (s))
389
392
algvars = map (i-> fullvars[i], algvars_range (s))
390
393
@@ -416,8 +419,13 @@ function build_observed_function(
416
419
if ! isempty (subset)
417
420
eqs = equations (sys)
418
421
419
- torn_eqs = map (i-> map (v-> eqs[var_eq_matching[v]], var_sccs[i]), subset)
420
- torn_vars = map (i-> map (v-> fullvars[v], var_sccs[i]), subset)
422
+ nested_torn_vars_idxs = []
423
+ for iscc in subset
424
+ torn_vars_idxs = Int[var for var in var_sccs[iscc] if var_eq_matching[var] != = unassigned]
425
+ isempty (torn_vars_idxs) || push! (nested_torn_vars_idxs, torn_vars_idxs)
426
+ end
427
+ torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs] for idxs in nested_torn_vars_idxs]
428
+ torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs]
421
429
u0map = defaults (sys)
422
430
assignments = copy (assignments)
423
431
solves = map (zip (torn_eqs, torn_vars)) do (eqs, vars)
0 commit comments