@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
4
4
5
5
const MAX_INLINE_NLSOLVE_SIZE = 8
6
6
7
- function torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
7
+ function torn_system_with_nlsolve_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
8
8
s = structure (sys)
9
9
@unpack fullvars, graph = s
10
10
@@ -95,30 +95,71 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
95
95
sparse (I, J, true )
96
96
end
97
97
98
- """
99
- exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)
100
-
101
- Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
102
- as well as a call to the root-finding solver.
103
-
104
- `exprs` is a two element vector
105
- ```
106
- exprs = [fname = f, numerical_nlsolve(fname, ...)]
107
- ```
108
-
109
- # Arguments:
110
- - `eqs`: Equations to find roots of.
111
- - `vars`: ???
112
- - `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
113
- - `checkbounds`: Apply bounds checking in the generated code.
114
- """
115
- function gen_nlsolve (eqs, vars, u0map:: AbstractDict ; checkbounds= true )
98
+ function gen_nlsolve! (is_not_prepended_assignment, eqs, vars, u0map:: AbstractDict , assignments, (deps, invdeps), var2assignment; checkbounds= true )
116
99
isempty (vars) && throw (ArgumentError (" vars may not be empty" ))
117
100
length (eqs) == length (vars) || throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" ))
118
101
rhss = map (x-> x. rhs, eqs)
119
102
# We use `vars` instead of `graph` to capture parameters, too.
120
- allvars = unique (collect (Iterators. flatten (map (ModelingToolkit. vars, rhss))))
121
- params = setdiff (allvars, vars) # these are not the subject of the root finding
103
+ paramset = ModelingToolkit. vars (r for r in rhss)
104
+
105
+ # Compute necessary assignments for the nlsolve expr
106
+ init_assignments = [var2assignment[p] for p in paramset if haskey (var2assignment, p)]
107
+ tmp = [init_assignments]
108
+ # `deps[init_assignments]` gives the dependencies of each assignment in `init_assignments`
109
+ while true
110
+ next_assignments = reduce (vcat, deps[init_assignments])
111
+ isempty (next_assignments) && break
112
+ init_assignments = next_assignments
113
+ push! (tmp, init_assignments)
114
+ end
115
+ needed_assignments_idxs = reduce (vcat, unique (reverse (tmp)))
116
+ needed_assignments = assignments[needed_assignments_idxs]
117
+
118
+ # Compute `params`. They act like closed-over (captured) variables of the root-finding function
119
+ rhsvars = [ModelingToolkit. vars (r. rhs) for r in needed_assignments]
120
+ vars_set = Set (vars)
121
+ outer_set = BitSet ()
122
+ inner_set = BitSet ()
123
+ for (i, vs) in enumerate (rhsvars)
124
+ j = needed_assignments_idxs[i]
125
+ if isdisjoint (vars_set, vs)
126
+ push! (outer_set, j)
127
+ else
128
+ push! (inner_set, j)
129
+ end
130
+ end
131
+ init_refine = BitSet ()
132
+ for i in inner_set
133
+ union! (init_refine, invdeps[i])
134
+ end
135
+ intersect! (init_refine, outer_set)
136
+ setdiff! (outer_set, init_refine)
137
+ union! (inner_set, init_refine)
138
+
139
+ next_refine = BitSet ()
140
+ while true
141
+ for i in init_refine
142
+ id = invdeps[i]
143
+ isempty (id) && break
144
+ union! (next_refine, id)
145
+ end
146
+ intersect! (next_refine, outer_set)
147
+ isempty (next_refine) && break
148
+ setdiff! (outer_set, next_refine)
149
+ union! (inner_set, next_refine)
150
+
151
+ init_refine, next_refine = next_refine, init_refine
152
+ empty! (next_refine)
153
+ end
154
+ global2local = Dict (j=> i for (i, j) in enumerate (needed_assignments_idxs))
155
+ inner_idxs = [global2local[i] for i in collect (inner_set)]
156
+ outer_idxs = [global2local[i] for i in collect (outer_set)]
157
+ extravars = reduce (union!, rhsvars[inner_idxs], init= Set ())
158
+ union! (paramset, extravars)
159
+ setdiff! (paramset, vars)
160
+ setdiff! (paramset, [needed_assignments[i]. lhs for i in inner_idxs])
161
+ union! (paramset, [needed_assignments[i]. lhs for i in outer_idxs])
162
+ params = collect (paramset)
122
163
123
164
# splatting to tighten the type
124
165
u0 = []
@@ -144,7 +185,11 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
144
185
DestructuredArgs (params, inbounds= ! checkbounds)
145
186
],
146
187
[],
147
- isscalar ? rhss[1 ] : MakeArray (rhss, SVector)
188
+ Let (
189
+ needed_assignments[inner_idxs],
190
+ isscalar ? rhss[1 ] : MakeArray (rhss, SVector),
191
+ false
192
+ )
148
193
) |> SymbolicUtils. Code. toexpr
149
194
150
195
# solver call contains code to call the root-finding solver on the function f
@@ -158,10 +203,21 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
158
203
)
159
204
end )
160
205
161
- [
162
- fname ← @RuntimeGeneratedFunction (f)
163
- DestructuredArgs (vars, inbounds= ! checkbounds) ← solver_call
164
- ]
206
+ preassignments = []
207
+ for i in outer_idxs
208
+ ii = needed_assignments_idxs[i]
209
+ is_not_prepended_assignment[ii] || continue
210
+ is_not_prepended_assignment[ii] = false
211
+ push! (preassignments, assignments[ii])
212
+ end
213
+
214
+ nlsolve_expr = Assignment[
215
+ preassignments
216
+ fname ← @RuntimeGeneratedFunction (f)
217
+ DestructuredArgs (vars, inbounds= ! checkbounds) ← solver_call
218
+ ]
219
+
220
+ nlsolve_expr
165
221
end
166
222
167
223
function build_torn_function (
@@ -193,18 +249,30 @@ function build_torn_function(
193
249
194
250
states_idxs = collect (diffvars_range (s))
195
251
mass_matrix_diag = ones (length (states_idxs))
196
- torn_expr = []
252
+
253
+ assignments, deps, sol_states = tearing_assignments (sys)
254
+ invdeps = map (_-> BitSet (), deps)
255
+ for (i, d) in enumerate (deps)
256
+ for a in d
257
+ push! (invdeps[a], i)
258
+ end
259
+ end
260
+ var2assignment = Dict {Any,Int} (eq. lhs => i for (i, eq) in enumerate (assignments))
261
+ is_not_prepended_assignment = trues (length (assignments))
262
+
263
+ torn_expr = Assignment[]
264
+
197
265
defs = defaults (sys)
198
266
nlsolve_scc_idxs = Int[]
199
267
200
268
needs_extending = false
201
- for (i, scc) in enumerate (var_sccs)
202
- # torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
269
+ @views for (i, scc) in enumerate (var_sccs)
203
270
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
204
271
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205
272
isempty (torn_eqs_idxs) && continue
206
273
if length (torn_eqs_idxs) <= max_inlining_size
207
- append! (torn_expr, gen_nlsolve (eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
274
+ nlsolve_expr = gen_nlsolve! (is_not_prepended_assignment, eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, assignments, (deps, invdeps), var2assignment, checkbounds= checkbounds)
275
+ append! (torn_expr, nlsolve_expr)
208
276
push! (nlsolve_scc_idxs, i)
209
277
else
210
278
needs_extending = true
@@ -226,6 +294,7 @@ function build_torn_function(
226
294
227
295
states = s. fullvars[states_idxs]
228
296
syms = map (Symbol, states_idxs)
297
+
229
298
pre = get_postprocess_fbody (sys)
230
299
231
300
expr = SymbolicUtils. Code. toexpr (
@@ -238,26 +307,31 @@ function build_torn_function(
238
307
],
239
308
[],
240
309
pre (Let (
241
- torn_expr,
242
- funbody
310
+ [torn_expr; assignments[is_not_prepended_assignment]],
311
+ funbody,
312
+ false
243
313
))
244
- )
314
+ ),
315
+ sol_states
245
316
)
246
317
if expression
247
318
expr, states
248
319
else
249
- observedfun = let sys = sys, dict = Dict ()
320
+ observedfun = let sys= sys, dict= Dict (), assignments = assignments, deps = (deps, invdeps), sol_states = sol_states, var2assignment = var2assignment
250
321
function generated_observed (obsvar, u, p, t)
251
322
obs = get! (dict, value (obsvar)) do
252
- build_observed_function (sys, obsvar, var_eq_matching, var_sccs, checkbounds= checkbounds)
323
+ build_observed_function (sys, obsvar, var_eq_matching, var_sccs,
324
+ assignments, deps, sol_states, var2assignment,
325
+ checkbounds= checkbounds,
326
+ )
253
327
end
254
328
obs (u, p, t)
255
329
end
256
330
end
257
331
258
332
ODEFunction {true} (
259
333
@RuntimeGeneratedFunction (expr),
260
- sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing ,
334
+ sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity (sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing ,
261
335
syms = syms,
262
336
observed = observedfun,
263
337
mass_matrix = mass_matrix,
@@ -283,12 +357,17 @@ function find_solve_sequence(sccs, vars)
283
357
end
284
358
285
359
function build_observed_function (
286
- sys, ts, var_eq_matching, var_sccs;
360
+ sys, ts, var_eq_matching, var_sccs,
361
+ assignments,
362
+ deps,
363
+ sol_states,
364
+ var2assignment;
287
365
expression= false ,
288
366
output_type= Array,
289
- checkbounds= true
367
+ checkbounds= true ,
290
368
)
291
369
370
+ is_not_prepended_assignment = trues (length (assignments))
292
371
if (isscalar = ! (ts isa AbstractVector))
293
372
ts = [ts]
294
373
end
@@ -335,7 +414,11 @@ function build_observed_function(
335
414
torn_eqs = map (i-> map (v-> eqs[var_eq_matching[v]], var_sccs[i]), subset)
336
415
torn_vars = map (i-> map (v-> fullvars[v], var_sccs[i]), subset)
337
416
u0map = defaults (sys)
338
- solves = gen_nlsolve .(torn_eqs, torn_vars, (u0map,); checkbounds= checkbounds)
417
+ assignments = copy (assignments)
418
+ solves = map (zip (torn_eqs, torn_vars)) do (eqs, vars)
419
+ gen_nlsolve! (is_not_prepended_assignment, eqs, vars,
420
+ u0map, assignments, deps, var2assignment; checkbounds= checkbounds)
421
+ end
339
422
else
340
423
solves = []
341
424
end
@@ -348,7 +431,7 @@ function build_observed_function(
348
431
end
349
432
pre = get_postprocess_fbody (sys)
350
433
351
- ex = Func (
434
+ ex = Code . toexpr ( Func (
352
435
[
353
436
DestructuredArgs (diffvars, inbounds= ! checkbounds)
354
437
DestructuredArgs (parameters (sys), inbounds= ! checkbounds)
@@ -360,10 +443,12 @@ function build_observed_function(
360
443
collect (Iterators. flatten (solves))
361
444
map (eq -> eq. lhs← eq. rhs, obs[1 : maxidx])
362
445
subs
446
+ assignments[is_not_prepended_assignment]
363
447
],
364
- isscalar ? ts[1 ] : MakeArray (ts, output_type)
448
+ isscalar ? ts[1 ] : MakeArray (ts, output_type),
449
+ false
365
450
))
366
- ) |> Code . toexpr
451
+ ), sol_states)
367
452
368
453
expression ? ex : @RuntimeGeneratedFunction (ex)
369
454
end
0 commit comments