1
1
using LinearAlgebra
2
2
3
+ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfinding_callback, generate_difference_cb, merge_cb
4
+
3
5
const MAX_INLINE_NLSOLVE_SIZE = 8
4
6
5
- function torn_system_jacobian_sparsity (state, var_eq_matching, var_sccs)
7
+ function torn_system_jacobian_sparsity (state, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs )
6
8
fullvars = state. fullvars
7
9
graph = state. structure. graph
8
10
@@ -40,55 +42,53 @@ function torn_system_jacobian_sparsity(state, var_eq_matching, var_sccs)
40
42
# from previous partitions. Hence, we can build the dependency chain as we
41
43
# traverse the partitions.
42
44
43
- # `avars2dvars` maps an algebraic variable to its differential variable
44
- # dependencies.
45
- avars2dvars = Dict {Int,Set{Int}} ()
46
- c = 0
47
- for scc in var_sccs
48
- v_residual = scc
49
- e_residual = [var_eq_matching[c] for c in v_residual if var_eq_matching[c] != = unassigned]
50
- # initialization
51
- for tvar in v_residual
52
- avars2dvars[tvar] = Set {Int} ()
45
+ var_rename = ones (Int64, ndsts (graph))
46
+ nlsolve_vars = Int[]
47
+ for i in nlsolve_scc_idxs, c in var_sccs[i]
48
+ append! (nlsolve_vars, c)
49
+ for v in c
50
+ var_rename[v] = 0
53
51
end
54
- for teq in e_residual
55
- c += 1
56
- for var in 𝑠neighbors (graph, teq)
57
- # Skip the tearing variables in the current partition, because
58
- # we are computing them from all the other states.
59
- Graphs. insorted (var, v_residual) && continue
60
- deps = get (avars2dvars, var, nothing )
61
- if deps === nothing # differential variable
62
- @assert ! isalgvar (state. structure, var)
63
- for tvar in v_residual
64
- push! (avars2dvars[tvar], var)
65
- end
66
- else # tearing variable from previous partitions
67
- @assert isalgvar (state. structure, var)
68
- for tvar in v_residual
69
- union! (avars2dvars[tvar], avars2dvars[var])
70
- end
71
- end
52
+ end
53
+ masked_cumsum! (var_rename)
54
+
55
+ dig = DiCMOBiGraph {true} (graph, var_eq_matching)
56
+
57
+ fused_var_deps = map (1 : ndsts (graph)) do v
58
+ BitSet (v′ for v′ in neighborhood (dig, v, Inf ; dir= :in ) if var_rename[v′] != 0 )
59
+ end
60
+
61
+ for scc in var_sccs[nlsolve_scc_idxs]
62
+ if length (scc) >= 2
63
+ deps = fused_var_deps[scc[1 ]]
64
+ for c in 2 : length (scc)
65
+ union! (deps, fused_var_deps[c])
66
+ fused_var_deps[c] = deps
72
67
end
73
68
end
74
69
end
75
70
76
- dvrange = diffvars_range (state. structure)
77
- dvar2idx = Dict (v=> i for (i, v) in enumerate (dvrange))
71
+ var2idx = Dict {Int,Int} (v => i for (i, v) in enumerate (states_idxs))
72
+ eqs2idx = Dict {Int,Int} (v => i for (i, v) in enumerate (eqs_idxs))
73
+ nlsolve_vars_set = BitSet (nlsolve_vars)
74
+
78
75
I = Int[]; J = Int[]
79
- eqidx = 0
80
- aeqs = algeqs (state. structure)
76
+ s = state. structure
81
77
for ieq in 𝑠vertices (graph)
82
- ieq in aeqs && continue
83
- eqidx += 1
78
+ nieq = get (eqs2idx, ieq, 0 )
79
+ nieq == 0 && continue
84
80
for ivar in 𝑠neighbors (graph, ieq)
85
- if isdiffvar (state. structure, ivar)
86
- push! (I, eqidx)
87
- push! (J, dvar2idx[ivar])
88
- elseif isalgvar (state. structure, ivar)
89
- for dvar in avars2dvars[ivar]
90
- push! (I, eqidx)
91
- push! (J, dvar2idx[dvar])
81
+ isdervar (s, ivar) && continue
82
+ if var_rename[ivar] != 0
83
+ push! (I, nieq)
84
+ push! (J, var2idx[ivar])
85
+ else
86
+ for dvar in fused_var_deps[ivar]
87
+ isdervar (s, dvar) && continue
88
+ niv = get (var2idx, dvar, 0 )
89
+ niv == 0 && continue
90
+ push! (I, nieq)
91
+ push! (J, niv)
92
92
end
93
93
end
94
94
end
@@ -122,7 +122,17 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
122
122
params = setdiff (allvars, vars) # these are not the subject of the root finding
123
123
124
124
# splatting to tighten the type
125
- u0 = [map (var-> get (u0map, var, 1e-3 ), vars)... ]
125
+ u0 = []
126
+ for v in vars
127
+ v in keys (u0map) || (push! (u0, 1e-3 ); continue )
128
+ u = substitute (v, u0map)
129
+ for i in 1 : length (u0map)
130
+ u = substitute (u, u0map)
131
+ u isa Number && (push! (u0, u); break )
132
+ end
133
+ u isa Number || error (" $v doesn't have a default." )
134
+ end
135
+ u0 = [u0... ]
126
136
# specialize on the scalar case
127
137
isscalar = length (u0) == 1
128
138
u0 = isscalar ? u0[1 ] : SVector (u0... )
@@ -167,8 +177,11 @@ function build_torn_function(
167
177
max_inlining_size = something (max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
168
178
rhss = []
169
179
eqs = equations (sys)
170
- for eq in eqs
171
- isdiffeq (eq) && push! (rhss, eq. rhs)
180
+ eqs_idxs = Int[]
181
+ for (i, eq) in enumerate (eqs)
182
+ isdiffeq (eq) || continue
183
+ push! (eqs_idxs, i)
184
+ push! (rhss, eq. rhs)
172
185
end
173
186
174
187
state = TearingState (sys)
@@ -179,23 +192,26 @@ function build_torn_function(
179
192
toporder = topological_sort_by_dfs (condensed_graph)
180
193
var_sccs = var_sccs[toporder]
181
194
182
- states = map (i -> fullvars[i], diffvars_range (state. structure))
183
- mass_matrix_diag = ones (length (states ))
195
+ states_idxs = collect ( diffvars_range (state. structure))
196
+ mass_matrix_diag = ones (length (states_idxs ))
184
197
torn_expr = []
185
198
defs = defaults (sys)
199
+ nlsolve_scc_idxs = Int[]
186
200
187
201
needs_extending = false
188
- for scc in var_sccs
189
- torn_vars = [fullvars[var] for var in scc if var_eq_matching[var] != = unassigned]
190
- torn_eqs = [eqs[var_eq_matching[var]] for var in scc if var_eq_matching[var] != = unassigned]
191
- isempty (torn_eqs) && continue
192
- if length (torn_eqs) <= max_inlining_size
193
- append! (torn_expr, gen_nlsolve (torn_eqs, torn_vars, defs, checkbounds= checkbounds))
202
+ for (i, scc) in enumerate (var_sccs)
203
+ torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
204
+ torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205
+ isempty (torn_eqs_idxs) && continue
206
+ if length (torn_eqs_idxs) <= max_inlining_size
207
+ append! (torn_expr, gen_nlsolve (eqs[torn_eqs_idxs], fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
208
+ push! (nlsolve_scc_idxs, i)
194
209
else
195
210
needs_extending = true
196
- append! (rhss, map (x-> x. rhs, torn_eqs))
197
- append! (states, torn_vars)
198
- append! (mass_matrix_diag, zeros (length (torn_eqs)))
211
+ append! (eqs_idxs, torn_eqs_idxs)
212
+ append! (rhss, map (x-> x. rhs, eqs[torn_eqs_idxs]))
213
+ append! (states_idxs, torn_vars_idxs)
214
+ append! (mass_matrix_diag, zeros (length (torn_eqs_idxs)))
199
215
end
200
216
end
201
217
@@ -208,7 +224,8 @@ function build_torn_function(
208
224
rhss
209
225
)
210
226
211
- syms = map (Symbol, states)
227
+ states = fullvars[states_idxs]
228
+ syms = map (Symbol, states_idxs)
212
229
pre = get_postprocess_fbody (sys)
213
230
214
231
expr = SymbolicUtils. Code. toexpr (
@@ -240,7 +257,7 @@ function build_torn_function(
240
257
241
258
ODEFunction {true} (
242
259
@RuntimeGeneratedFunction (expr),
243
- sparsity = torn_system_jacobian_sparsity (state, var_eq_matching, var_sccs) ,
260
+ sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity (state, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing ,
244
261
syms = syms,
245
262
observed = observedfun,
246
263
mass_matrix = mass_matrix,
@@ -375,14 +392,29 @@ function ODAEProblem{iip}(
375
392
u0map,
376
393
tspan,
377
394
parammap= DiffEqBase. NullParameters ();
378
- kw...
395
+ callback = nothing ,
396
+ kwargs...
379
397
) where {iip}
380
- fun, dvs = build_torn_function (sys; kw ... )
398
+ fun, dvs = build_torn_function (sys; kwargs ... )
381
399
ps = parameters (sys)
382
400
defs = defaults (sys)
383
401
384
402
u0 = ModelingToolkit. varmap_to_vars (u0map, dvs; defaults= defs)
385
403
p = ModelingToolkit. varmap_to_vars (parammap, ps; defaults= defs)
386
404
387
- ODEProblem {iip} (fun, u0, tspan, p; kw... )
405
+ has_difference = any (isdifferenceeq, equations (sys))
406
+ if has_continuous_events (sys)
407
+ event_cb = generate_rootfinding_callback (sys; kwargs... )
408
+ else
409
+ event_cb = nothing
410
+ end
411
+ difference_cb = has_difference ? generate_difference_cb (sys; kwargs... ) : nothing
412
+ cb = merge_cb (event_cb, difference_cb)
413
+ cb = merge_cb (cb, callback)
414
+
415
+ if cb === nothing
416
+ ODEProblem {iip} (fun, u0, tspan, p; kwargs... )
417
+ else
418
+ ODEProblem {iip} (fun, u0, tspan, p; callback= cb, kwargs... )
419
+ end
388
420
end
0 commit comments