@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
4
4
5
5
const MAX_INLINE_NLSOLVE_SIZE = 8
6
6
7
- function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs)
7
+ function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs )
8
8
s = structure (sys)
9
9
@unpack fullvars, graph = s
10
10
@@ -42,54 +42,52 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
42
42
# from previous partitions. Hence, we can build the dependency chain as we
43
43
# traverse the partitions.
44
44
45
- # `avars2dvars` maps a algebraic variable to its differential variable
46
- # dependencies.
47
- avars2dvars = Dict {Int,Set{Int}} ()
48
- c = 0
49
- for scc in var_sccs
50
- v_residual = scc
51
- e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] != = unassigned]
52
- # initialization
53
- for tvar in v_residual
54
- avars2dvars[tvar] = Set {Int} ()
45
+ var_rename = ones (Int64, ndsts (graph))
46
+ nlsolve_vars = Int[]
47
+ for i in nlsolve_scc_idxs, c in var_sccs[i]
48
+ append! (nlsolve_vars, c)
49
+ for v in c
50
+ var_rename[v] = 0
55
51
end
56
- for teq in e_residual
57
- c += 1
58
- for var in 𝑠neighbors (graph, teq)
59
- # Skip the tearing variables in the current partition, because
60
- # we are computing them from all the other states.
61
- Graphs. insorted (var, v_residual) && continue
62
- deps = get (avars2dvars, var, nothing )
63
- if deps === nothing # differential variable
64
- @assert ! isalgvar (s, var)
65
- for tvar in v_residual
66
- push! (avars2dvars[tvar], var)
67
- end
68
- else # tearing variable from previous partitions
69
- @assert isalgvar (s, var)
70
- for tvar in v_residual
71
- union! (avars2dvars[tvar], avars2dvars[var])
72
- end
73
- end
52
+ end
53
+ masked_cumsum! (var_rename)
54
+
55
+ dig = DiCMOBiGraph {true} (graph, var_eq_matching)
56
+
57
+ fused_var_deps = map (1 : ndsts (graph)) do v
58
+ BitSet (var_rename[v′] for v′ in neighborhood (dig, v, Inf ; dir= :in ) if var_rename[v′] != 0 )
59
+ end
60
+
61
+ for scc in var_sccs
62
+ if length (scc) >= 2
63
+ deps = fused_var_deps[scc[1 ]]
64
+ for c in 2 : length (scc)
65
+ # `c` is a position inside `scc` (2:length(scc)), whereas `fused_var_deps` is indexed by variable (it is built by mapping over 1:ndsts(graph)); look up the variable `scc[c]` so the dependency sets of the SCC's own members are merged into `deps`.
+ union!(deps, fused_var_deps[scc[c]])
74
66
end
75
67
end
76
68
end
77
69
78
- dvrange = diffvars_range (s)
79
- dvar2idx = Dict (v=> i for (i, v) in enumerate (dvrange))
70
+ nlsolve_eqs = BitSet (var_eq_matching[c]:: Int for c in nlsolve_vars if var_eq_matching[c] != = unassigned)
71
+
72
+ var2idx = Dict (v => i for (i, v) in enumerate (states_idxs))
73
+ nlsolve_vars_set = BitSet (nlsolve_vars)
74
+
80
75
I = Int[]; J = Int[]
81
76
eqidx = 0
82
77
for ieq in 𝑠vertices (graph)
83
- isalgeq (s, ieq) && continue
78
+ ieq in nlsolve_eqs && continue
84
79
eqidx += 1
85
80
for ivar in 𝑠neighbors (graph, ieq)
86
- if isdiffvar (s, ivar)
81
+ isdervar (s, ivar) && continue
82
+ if var_rename[ivar] != 0
87
83
push! (I, eqidx)
88
- push! (J, dvar2idx[ivar])
89
- elseif isalgvar (s, ivar)
90
- for dvar in avars2dvars[ivar]
84
+ push! (J, var2idx[ivar])
85
+ else
86
+ for dvar in fused_var_deps[ivar]
87
+ isdervar (s, dvar) && continue
88
+ dvar in nlsolve_vars_set && continue
91
89
push! (I, eqidx)
92
- push! (J, dvar2idx [dvar])
90
+ push! (J, var2idx [dvar])
93
91
end
94
92
end
95
93
end
@@ -180,22 +178,25 @@ function build_torn_function(
180
178
toporder = topological_sort_by_dfs (condensed_graph)
181
179
var_sccs = var_sccs[toporder]
182
180
183
- states = map (i -> s . fullvars[i], diffvars_range (s))
184
- mass_matrix_diag = ones (length (states ))
181
+ states_idxs = collect ( diffvars_range (s))
182
+ mass_matrix_diag = ones (length (states_idxs ))
185
183
torn_expr = []
186
184
defs = defaults (sys)
185
+ nlsolve_scc_idxs = Int[]
187
186
188
187
needs_extending = false
189
- for scc in var_sccs
190
- torn_vars = [s. fullvars[var] for var in scc if var_eq_matching[var] != = unassigned]
191
- torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] != = unassigned]
188
+ for (i, scc) in enumerate (var_sccs)
189
+ # torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
190
+ torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
191
+ torn_eqs = [eqs[var_eq_matching[var]] for var in torn_vars_idxs]
192
192
isempty (torn_eqs) && continue
193
193
if length (torn_eqs) <= max_inlining_size
194
- append! (torn_expr, gen_nlsolve (torn_eqs, torn_vars, defs, checkbounds= checkbounds))
194
+ append! (torn_expr, gen_nlsolve (torn_eqs, s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
195
+ push! (nlsolve_scc_idxs, i)
195
196
else
196
197
needs_extending = true
197
198
append! (rhss, map (x-> x. rhs, torn_eqs))
198
- append! (states, torn_vars )
199
+ append! (states_idxs, torn_vars_idxs )
199
200
append! (mass_matrix_diag, zeros (length (torn_eqs)))
200
201
end
201
202
end
@@ -209,7 +210,8 @@ function build_torn_function(
209
210
rhss
210
211
)
211
212
212
- syms = map (Symbol, states)
213
+ states = s. fullvars[states_idxs]
214
+ # Name the states after the variables themselves, not their integer positions in `fullvars`: `Symbol(3)` would yield `Symbol("3")`, which is useless as a state name. `states` is the `s.fullvars[states_idxs]` lookup computed just above.
+ syms = map(Symbol, states)
213
215
pre = get_postprocess_fbody (sys)
214
216
215
217
expr = SymbolicUtils. Code. toexpr (
@@ -241,7 +243,7 @@ function build_torn_function(
241
243
242
244
ODEFunction {true} (
243
245
@RuntimeGeneratedFunction (expr),
244
- sparsity = torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs) ,
246
+ sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs) : nothing ,
245
247
syms = syms,
246
248
observed = observedfun,
247
249
mass_matrix = mass_matrix,
0 commit comments