@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
4
4
5
5
const MAX_INLINE_NLSOLVE_SIZE = 8
6
6
7
- function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs)
7
+ function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs )
8
8
s = structure (sys)
9
9
@unpack fullvars, graph = s
10
10
@@ -42,54 +42,52 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs)
42
42
# from previous partitions. Hence, we can build the dependency chain as we
43
43
# traverse the partitions.
44
44
45
- # `avars2dvars` maps a algebraic variable to its differential variable
46
- # dependencies.
47
- avars2dvars = Dict {Int,Set{Int}} ()
48
- c = 0
49
- for scc in var_sccs
50
- v_residual = scc
51
- e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] != = unassigned]
52
- # initialization
53
- for tvar in v_residual
54
- avars2dvars[tvar] = Set {Int} ()
45
+ var_rename = ones (Int64, ndsts (graph))
46
+ nlsolve_vars = Int[]
47
+ for i in nlsolve_scc_idxs, c in var_sccs[i]
48
+ append! (nlsolve_vars, c)
49
+ for v in c
50
+ var_rename[v] = 0
55
51
end
56
- for teq in e_residual
57
- c += 1
58
- for var in 𝑠neighbors (graph, teq)
59
- # Skip the tearing variables in the current partition, because
60
- # we are computing them from all the other states.
61
- Graphs. insorted (var, v_residual) && continue
62
- deps = get (avars2dvars, var, nothing )
63
- if deps === nothing # differential variable
64
- @assert ! isalgvar (s, var)
65
- for tvar in v_residual
66
- push! (avars2dvars[tvar], var)
67
- end
68
- else # tearing variable from previous partitions
69
- @assert isalgvar (s, var)
70
- for tvar in v_residual
71
- union! (avars2dvars[tvar], avars2dvars[var])
72
- end
73
- end
52
+ end
53
+ masked_cumsum! (var_rename)
54
+
55
+ dig = DiCMOBiGraph {true} (graph, var_eq_matching)
56
+
57
+ fused_var_deps = map (1 : ndsts (graph)) do v
58
+ BitSet (var_rename[v′] for v′ in neighborhood (dig, v, Inf ; dir= :in ) if var_rename[v′] != 0 )
59
+ end
60
+
61
# Merge the dependency sets of all variables that share a component in
# `var_sccs`. `fused_var_deps` is indexed by variable index, while `scc` holds
# variable indices, so every member has to be looked up through `scc[k]`
# (indexing with the loop position `k` itself would read an unrelated variable).
for scc in var_sccs
    length(scc) >= 2 || continue
    # `deps` aliases the entry of the first member and is updated in place.
    deps = fused_var_deps[scc[1]]
    for k in 2:length(scc)
        union!(deps, fused_var_deps[scc[k]])
    end
end
77
69
78
- dvrange = diffvars_range (s)
79
- dvar2idx = Dict (v=> i for (i, v) in enumerate (dvrange))
70
+ nlsolve_eqs = BitSet (var_eq_matching[c]:: Int for c in nlsolve_vars if var_eq_matching[c] != = unassigned)
71
+
72
+ var2idx = Dict (v => i for (i, v) in enumerate (states_idxs))
73
+ nlsolve_vars_set = BitSet (nlsolve_vars)
74
+
80
75
I = Int[]; J = Int[]
81
76
eqidx = 0
82
77
for ieq in 𝑠vertices (graph)
83
- isalgeq (s, ieq) && continue
78
+ ieq in nlsolve_eqs && continue
84
79
eqidx += 1
85
80
for ivar in 𝑠neighbors (graph, ieq)
86
- if isdiffvar (s, ivar)
81
+ isdervar (s, ivar) && continue
82
+ if var_rename[ivar] != 0
87
83
push! (I, eqidx)
88
- push! (J, dvar2idx[ivar])
89
- elseif isalgvar (s, ivar)
90
- for dvar in avars2dvars[ivar]
84
+ push! (J, var2idx[ivar])
85
+ else
86
+ for dvar in fused_var_deps[ivar]
87
+ isdervar (s, dvar) && continue
88
+ dvar in nlsolve_vars_set && continue
91
89
push! (I, eqidx)
92
- push! (J, dvar2idx [dvar])
90
+ push! (J, var2idx [dvar])
93
91
end
94
92
end
95
93
end
@@ -123,7 +121,17 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
123
121
params = setdiff (allvars, vars) # these are not the subject of the root finding
124
122
125
123
# splatting to tighten the type
126
- u0 = [map (var-> get (u0map, var, 1e-3 ), vars)... ]
124
+ u0 = []
125
+ for v in vars
126
+ v in keys (u0map) || (push! (u0, 1e-3 ); continue )
127
+ u = substitute (v, u0map)
128
+ for i in 1 : length (u0map)
129
+ u = substitute (u, u0map)
130
+ u isa Number && (push! (u0, u); break )
131
+ end
132
+ u isa Number || error (" $v doesn't have a default." )
133
+ end
134
+ u0 = [u0... ]
127
135
# specialize on the scalar case
128
136
isscalar = length (u0) == 1
129
137
u0 = isscalar ? u0[1 ] : SVector (u0... )
@@ -175,23 +183,30 @@ function build_torn_function(
175
183
s = structure (sys)
176
184
@unpack fullvars = s
177
185
var_eq_matching, var_sccs = algebraic_variables_scc (sys)
186
+ condensed_graph = MatchedCondensationGraph (
187
+ DiCMOBiGraph {true} (complete (s. graph), complete (var_eq_matching)), var_sccs)
188
+ toporder = topological_sort_by_dfs (condensed_graph)
189
+ var_sccs = var_sccs[toporder]
178
190
179
- states = map (i -> s . fullvars[i], diffvars_range (s))
180
- mass_matrix_diag = ones (length (states ))
191
+ states_idxs = collect ( diffvars_range (s))
192
+ mass_matrix_diag = ones (length (states_idxs ))
181
193
torn_expr = []
182
194
defs = defaults (sys)
195
+ nlsolve_scc_idxs = Int[]
183
196
184
197
needs_extending = false
185
- for scc in var_sccs
186
- torn_vars = [s. fullvars[var] for var in scc if var_eq_matching[var] != = unassigned]
187
- torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] != = unassigned]
198
+ for (i, scc) in enumerate (var_sccs)
199
+ # torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
200
+ torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
201
+ torn_eqs = [eqs[var_eq_matching[var]] for var in torn_vars_idxs]
188
202
isempty (torn_eqs) && continue
189
203
if length (torn_eqs) <= max_inlining_size
190
- append! (torn_expr, gen_nlsolve (torn_eqs, torn_vars, defs, checkbounds= checkbounds))
204
+ append! (torn_expr, gen_nlsolve (torn_eqs, s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
205
+ push! (nlsolve_scc_idxs, i)
191
206
else
192
207
needs_extending = true
193
208
append! (rhss, map (x-> x. rhs, torn_eqs))
194
- append! (states, torn_vars )
209
+ append! (states_idxs, torn_vars_idxs )
195
210
append! (mass_matrix_diag, zeros (length (torn_eqs)))
196
211
end
197
212
end
@@ -205,7 +220,8 @@ function build_torn_function(
205
220
rhss
206
221
)
207
222
208
- syms = map (Symbol, states)
223
# `states_idxs` holds positions into `s.fullvars`; resolve them to the actual
# variables first so that `syms` is derived from the variables themselves and
# not from their integer positions (which would yield Symbol("1"), Symbol("2"), ...).
states = s.fullvars[states_idxs]
syms = map(Symbol, states)
209
225
pre = get_postprocess_fbody (sys)
210
226
211
227
expr = SymbolicUtils. Code. toexpr (
@@ -237,7 +253,7 @@ function build_torn_function(
237
253
238
254
ODEFunction {true} (
239
255
@RuntimeGeneratedFunction (expr),
240
- sparsity = torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs) ,
256
+ sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, states_idxs) : nothing ,
241
257
syms = syms,
242
258
observed = observedfun,
243
259
mass_matrix = mass_matrix,
0 commit comments