@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
4
4
5
5
const MAX_INLINE_NLSOLVE_SIZE = 8
6
6
7
- function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs)
7
+ function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
8
8
s = structure (sys)
9
9
@unpack fullvars, graph = s
10
10
@@ -55,39 +55,39 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
55
55
dig = DiCMOBiGraph {true} (graph, var_eq_matching)
56
56
57
57
fused_var_deps = map (1 : ndsts (graph)) do v
58
- BitSet (var_rename[v′] for v′ in neighborhood (dig, v, Inf ; dir= :in ) if var_rename[v′] != 0 )
58
+ BitSet (v′ for v′ in neighborhood (dig, v, Inf ; dir= :in ) if var_rename[v′] != 0 )
59
59
end
60
60
61
- for scc in var_sccs
61
+ for scc in var_sccs[nlsolve_scc_idxs]
62
62
if length (scc) >= 2
63
63
deps = fused_var_deps[scc[1 ]]
64
64
for c in 2 : length (scc)
65
65
union! (deps, fused_var_deps[c])
66
+ fused_var_deps[c] = deps
66
67
end
67
68
end
68
69
end
69
70
70
- nlsolve_eqs = BitSet (var_eq_matching[c]:: Int for c in nlsolve_vars if var_eq_matching[c] != = unassigned)
71
-
72
- var2idx = Dict (v => i for (i, v) in enumerate (states_idxs))
71
+ var2idx = Dict {Int,Int} (v => i for (i, v) in enumerate (states_idxs))
72
+ eqs2idx = Dict {Int,Int} (v => i for (i, v) in enumerate (eqs_idxs))
73
73
nlsolve_vars_set = BitSet (nlsolve_vars)
74
74
75
75
I = Int[]; J = Int[]
76
- eqidx = 0
77
76
for ieq in 𝑠vertices (graph)
78
- ieq in nlsolve_eqs && continue
79
- eqidx += 1
77
+ nieq = get (eqs2idx, ieq, 0 )
78
+ nieq == 0 && continue
80
79
for ivar in 𝑠neighbors (graph, ieq)
81
80
isdervar (s, ivar) && continue
82
81
if var_rename[ivar] != 0
83
- push! (I, eqidx )
82
+ push! (I, nieq )
84
83
push! (J, var2idx[ivar])
85
84
else
86
85
for dvar in fused_var_deps[ivar]
87
86
isdervar (s, dvar) && continue
88
- dvar in nlsolve_vars_set && continue
89
- push! (I, eqidx)
90
- push! (J, var2idx[dvar])
87
+ niv = get (var2idx, dvar, 0 )
88
+ niv == 0 && continue
89
+ push! (I, nieq)
90
+ push! (J, niv)
91
91
end
92
92
end
93
93
end
@@ -176,8 +176,11 @@ function build_torn_function(
176
176
max_inlining_size = something (max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
177
177
rhss = []
178
178
eqs = equations (sys)
179
- for eq in eqs
180
- isdiffeq (eq) && push! (rhss, eq. rhs)
179
+ eqs_idxs = Int[]
180
+ for (i, eq) in enumerate (eqs)
181
+ isdiffeq (eq) || continue
182
+ push! (eqs_idxs, i)
183
+ push! (rhss, eq. rhs)
181
184
end
182
185
183
186
s = structure (sys)
@@ -198,16 +201,17 @@ function build_torn_function(
198
201
for (i, scc) in enumerate (var_sccs)
199
202
# torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
200
203
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
201
- torn_eqs = [eqs[ var_eq_matching[var] ] for var in torn_vars_idxs]
202
- isempty (torn_eqs ) && continue
203
- if length (torn_eqs ) <= max_inlining_size
204
- append! (torn_expr, gen_nlsolve (torn_eqs , s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
204
+ torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205
+ isempty (torn_eqs_idxs ) && continue
206
+ if length (torn_eqs_idxs ) <= max_inlining_size
207
+ append! (torn_expr, gen_nlsolve (eqs[torn_eqs_idxs] , s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
205
208
push! (nlsolve_scc_idxs, i)
206
209
else
207
210
needs_extending = true
208
- append! (rhss, map (x-> x. rhs, torn_eqs))
211
+ append! (eqs_idxs, torn_eqs_idxs)
212
+ append! (rhss, map (x-> x. rhs, eqs[torn_eqs_idxs]))
209
213
append! (states_idxs, torn_vars_idxs)
210
- append! (mass_matrix_diag, zeros (length (torn_eqs )))
214
+ append! (mass_matrix_diag, zeros (length (torn_eqs_idxs )))
211
215
end
212
216
end
213
217
@@ -253,7 +257,7 @@ function build_torn_function(
253
257
254
258
ODEFunction {true} (
255
259
@RuntimeGeneratedFunction (expr),
256
- sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs) : nothing ,
260
+ sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing ,
257
261
syms = syms,
258
262
observed = observedfun,
259
263
mass_matrix = mass_matrix,
0 commit comments