1
+ using LinearAlgebra
2
+
3
+ const MAX_INLINE_NLSOLVE_SIZE = 8
4
+
1
5
function torn_system_jacobian_sparsity (sys)
2
6
s = structure (sys)
3
7
@unpack fullvars, graph, partitions = s
@@ -184,41 +188,54 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
184
188
]
185
189
end
186
190
187
- function get_torn_eqs_vars (sys; checkbounds= true )
188
- s = structure (sys)
189
- partitions = s. partitions
190
- vars = s. fullvars
191
- eqs = equations (sys)
192
-
193
- torn_eqs = map (idxs-> eqs[idxs], map (x-> x. e_residual, partitions))
194
- torn_vars = map (idxs-> vars[idxs], map (x-> x. v_residual, partitions))
195
- u0map = defaults (sys)
196
-
197
- gen_nlsolve .(torn_eqs, torn_vars, (u0map,), checkbounds= checkbounds)
198
- end
199
-
200
191
function build_torn_function (
201
192
sys;
202
193
expression= false ,
203
194
jacobian_sparsity= true ,
204
195
checkbounds= false ,
196
+ max_inlining_size= nothing ,
205
197
kw...
206
198
)
207
199
200
+ max_inlining_size = something (max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
208
201
rhss = []
209
- for eq in equations (sys)
202
+ eqs = equations (sys)
203
+ for eq in eqs
210
204
isdiffeq (eq) && push! (rhss, eq. rhs)
211
205
end
212
206
207
+ s = structure (sys)
208
+ @unpack fullvars, partitions = s
209
+
210
+ states = map (i-> s. fullvars[i], diffvars_range (s))
211
+ mass_matrix_diag = ones (length (states))
212
+ torn_expr = []
213
+ defs = defaults (sys)
214
+
215
+ needs_extending = false
216
+ for p in partitions
217
+ @unpack e_residual, v_residual = p
218
+ torn_eqs = eqs[e_residual]
219
+ torn_vars = fullvars[v_residual]
220
+ if length (e_residual) <= max_inlining_size
221
+ append! (torn_expr, gen_nlsolve (torn_eqs, torn_vars, defs, checkbounds= checkbounds))
222
+ else
223
+ needs_extending = true
224
+ append! (rhss, map (x-> x. rhs, torn_eqs))
225
+ append! (states, torn_vars)
226
+ append! (mass_matrix_diag, zeros (length (torn_eqs)))
227
+ end
228
+ end
229
+
230
+ mass_matrix = needs_extending ? Diagonal (mass_matrix_diag) : I
231
+
213
232
out = Sym {Any} (gensym (" out" ))
214
- odefunbody = SetArray (
233
+ funbody = SetArray (
215
234
! checkbounds,
216
235
out,
217
236
rhss
218
237
)
219
238
220
- s = structure (sys)
221
- states = map (i-> s. fullvars[i], diffvars_range (s))
222
239
syms = map (Symbol, states)
223
240
pre = get_postprocess_fbody (sys)
224
241
@@ -232,13 +249,13 @@ function build_torn_function(
232
249
],
233
250
[],
234
251
pre (Let (
235
- collect (Iterators . flatten ( get_torn_eqs_vars (sys, checkbounds = checkbounds))) ,
236
- odefunbody
252
+ torn_expr ,
253
+ funbody
237
254
))
238
255
)
239
256
)
240
257
if expression
241
- expr
258
+ expr, states
242
259
else
243
260
observedfun = let sys = sys, dict = Dict ()
244
261
function generated_observed (obsvar, u, p, t)
@@ -254,7 +271,8 @@ function build_torn_function(
254
271
sparsity = torn_system_jacobian_sparsity (sys),
255
272
syms = syms,
256
273
observed = observedfun,
257
- )
274
+ mass_matrix = mass_matrix,
275
+ ), states
258
276
end
259
277
end
260
278
@@ -385,14 +403,12 @@ function ODAEProblem{iip}(
385
403
parammap= DiffEqBase. NullParameters ();
386
404
kw...
387
405
) where {iip}
388
- s = structure (sys)
389
- @unpack fullvars = s
390
- dvs = map (i-> fullvars[i], diffvars_range (s))
406
+ fun, dvs = build_torn_function (sys; kw... )
391
407
ps = parameters (sys)
392
408
defs = defaults (sys)
393
409
394
410
u0 = ModelingToolkit. varmap_to_vars (u0map, dvs; defaults= defs)
395
411
p = ModelingToolkit. varmap_to_vars (parammap, ps; defaults= defs)
396
412
397
- ODEProblem {iip} (build_torn_function (sys; kw ... ) , u0, tspan, p; kw... )
413
+ ODEProblem {iip} (fun , u0, tspan, p; kw... )
398
414
end
0 commit comments