@@ -96,7 +96,7 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
96
96
end
97
97
98
98
"""
99
- exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)
99
+ exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true, assignments )
100
100
101
101
Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
102
102
as well as a call to the root-finding solver.
@@ -112,13 +112,24 @@ exprs = [fname = f, numerical_nlsolve(fname, ...)]
112
112
- `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
113
113
- `checkbounds`: Apply bounds checking in the generated code.
114
114
"""
115
- function gen_nlsolve (eqs, vars, u0map:: AbstractDict ; checkbounds= true )
115
+ function gen_nlsolve (eqs, vars, u0map:: AbstractDict ; checkbounds= true , assignments )
116
116
isempty (vars) && throw (ArgumentError (" vars may not be empty" ))
117
117
length (eqs) == length (vars) || throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" ))
118
118
rhss = map (x-> x. rhs, eqs)
119
119
# We use `vars` instead of `graph` to capture parameters, too.
120
- allvars = unique (collect (Iterators. flatten (map (ModelingToolkit. vars, rhss))))
121
- params = setdiff (allvars, vars) # these are not the subject of the root finding
120
+ allvars = Set (Iterators. flatten (ModelingToolkit. vars (r) for r in rhss))
121
+ vars_set = Set (vars)
122
+ params = setdiff (allvars, vars_set) # these are not the subject of the root finding
123
+ # needed_assignments = filter(a->a.lhs in params, assignments)
124
+ needed_assignments = assignments[1 : findlast (a-> a. lhs in params, assignments)]
125
+ params = setdiff (params, [a. lhs for a in needed_assignments])
126
+ @show needed_assignments, params
127
+ for a in needed_assignments
128
+ ModelingToolkit. vars! (params, a. rhs)
129
+ end
130
+ params = setdiff (params, vars_set) # these are not the subject of the root finding
131
+ @show params
132
+ # inductor1₊v, inductor2₊v
122
133
123
134
# splatting to tighten the type
124
135
u0 = []
@@ -141,10 +152,10 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
141
152
f = Func (
142
153
[
143
154
DestructuredArgs (vars, inbounds= ! checkbounds)
144
- DestructuredArgs (params, inbounds= ! checkbounds)
155
+ DestructuredArgs (collect ( params) , inbounds= ! checkbounds)
145
156
],
146
157
[],
147
- isscalar ? rhss[1 ] : MakeArray (rhss, SVector)
158
+ Let (needed_assignments, isscalar ? rhss[1 ] : MakeArray (rhss, SVector) )
148
159
) |> SymbolicUtils. Code. toexpr
149
160
150
161
# solver call contains code to call the root-finding solver on the function f
@@ -193,18 +204,21 @@ function build_torn_function(
193
204
194
205
states_idxs = collect (diffvars_range (s))
195
206
mass_matrix_diag = ones (length (states_idxs))
207
+
208
+ assignments, bf_states = tearing_assignments (sys)
196
209
torn_expr = []
210
+
197
211
defs = defaults (sys)
198
212
nlsolve_scc_idxs = Int[]
199
213
200
214
needs_extending = false
201
- for (i, scc) in enumerate (var_sccs)
215
+ @views for (i, scc) in enumerate (var_sccs)
202
216
# torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
203
217
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
204
218
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
205
219
isempty (torn_eqs_idxs) && continue
206
220
if length (torn_eqs_idxs) <= max_inlining_size
207
- append! (torn_expr, gen_nlsolve (eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds))
221
+ append! (torn_expr, gen_nlsolve (eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds, assignments = assignments ))
208
222
push! (nlsolve_scc_idxs, i)
209
223
else
210
224
needs_extending = true
@@ -226,6 +240,7 @@ function build_torn_function(
226
240
227
241
states = s. fullvars[states_idxs]
228
242
syms = map (Symbol, states_idxs)
243
+
229
244
pre = get_postprocess_fbody (sys)
230
245
231
246
expr = SymbolicUtils. Code. toexpr (
@@ -238,18 +253,22 @@ function build_torn_function(
238
253
],
239
254
[],
240
255
pre (Let (
241
- torn_expr,
256
+ [ torn_expr; assignments] ,
242
257
funbody
243
258
))
244
- )
259
+ ),
260
+ bf_states
245
261
)
246
262
if expression
247
263
expr, states
248
264
else
249
- observedfun = let sys = sys, dict = Dict ()
265
+ observedfun = let sys= sys, dict= Dict (), assignments = assignments, bf_states = bf_states
250
266
function generated_observed (obsvar, u, p, t)
251
267
obs = get! (dict, value (obsvar)) do
252
- build_observed_function (sys, obsvar, var_eq_matching, var_sccs, checkbounds= checkbounds)
268
+ build_observed_function (sys, obsvar, var_eq_matching, var_sccs,
269
+ checkbounds= checkbounds,
270
+ assignments= assignments,
271
+ bf_states= bf_states)
253
272
end
254
273
obs (u, p, t)
255
274
end
@@ -286,7 +305,9 @@ function build_observed_function(
286
305
sys, ts, var_eq_matching, var_sccs;
287
306
expression= false ,
288
307
output_type= Array,
289
- checkbounds= true
308
+ checkbounds= true ,
309
+ assignments,
310
+ bf_states,
290
311
)
291
312
292
313
if (isscalar = ! (ts isa AbstractVector))
@@ -348,7 +369,7 @@ function build_observed_function(
348
369
end
349
370
pre = get_postprocess_fbody (sys)
350
371
351
- ex = Func (
372
+ ex = Code . toexpr ( Func (
352
373
[
353
374
DestructuredArgs (diffvars, inbounds= ! checkbounds)
354
375
DestructuredArgs (parameters (sys), inbounds= ! checkbounds)
@@ -357,13 +378,14 @@ function build_observed_function(
357
378
[],
358
379
pre (Let (
359
380
[
381
+ assignments
360
382
collect (Iterators. flatten (solves))
361
383
map (eq -> eq. lhs← eq. rhs, obs[1 : maxidx])
362
384
subs
363
385
],
364
386
isscalar ? ts[1 ] : MakeArray (ts, output_type)
365
387
))
366
- ) |> Code . toexpr
388
+ ), bf_states)
367
389
368
390
expression ? ex : @RuntimeGeneratedFunction (ex)
369
391
end
0 commit comments