Skip to content

Commit 6b8c7ab

Browse files
authored
Limit the inline nlsolve size (#1367)
1 parent 980e61b commit 6b8c7ab

File tree

1 file changed

+41
-25
lines changed

1 file changed

+41
-25
lines changed

src/structural_transformation/codegen.jl

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using LinearAlgebra
2+
3+
const MAX_INLINE_NLSOLVE_SIZE = 8
4+
15
function torn_system_jacobian_sparsity(sys)
26
s = structure(sys)
37
@unpack fullvars, graph, partitions = s
@@ -184,41 +188,54 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
184188
]
185189
end
186190

187-
function get_torn_eqs_vars(sys; checkbounds=true)
188-
s = structure(sys)
189-
partitions = s.partitions
190-
vars = s.fullvars
191-
eqs = equations(sys)
192-
193-
torn_eqs = map(idxs-> eqs[idxs], map(x->x.e_residual, partitions))
194-
torn_vars = map(idxs->vars[idxs], map(x->x.v_residual, partitions))
195-
u0map = defaults(sys)
196-
197-
gen_nlsolve.(torn_eqs, torn_vars, (u0map,), checkbounds=checkbounds)
198-
end
199-
200191
function build_torn_function(
201192
sys;
202193
expression=false,
203194
jacobian_sparsity=true,
204195
checkbounds=false,
196+
max_inlining_size=nothing,
205197
kw...
206198
)
207199

200+
max_inlining_size = something(max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
208201
rhss = []
209-
for eq in equations(sys)
202+
eqs = equations(sys)
203+
for eq in eqs
210204
isdiffeq(eq) && push!(rhss, eq.rhs)
211205
end
212206

207+
s = structure(sys)
208+
@unpack fullvars, partitions = s
209+
210+
states = map(i->s.fullvars[i], diffvars_range(s))
211+
mass_matrix_diag = ones(length(states))
212+
torn_expr = []
213+
defs = defaults(sys)
214+
215+
needs_extending = false
216+
for p in partitions
217+
@unpack e_residual, v_residual = p
218+
torn_eqs = eqs[e_residual]
219+
torn_vars = fullvars[v_residual]
220+
if length(e_residual) <= max_inlining_size
221+
append!(torn_expr, gen_nlsolve(torn_eqs, torn_vars, defs, checkbounds=checkbounds))
222+
else
223+
needs_extending = true
224+
append!(rhss, map(x->x.rhs, torn_eqs))
225+
append!(states, torn_vars)
226+
append!(mass_matrix_diag, zeros(length(torn_eqs)))
227+
end
228+
end
229+
230+
mass_matrix = needs_extending ? Diagonal(mass_matrix_diag) : I
231+
213232
out = Sym{Any}(gensym("out"))
214-
odefunbody = SetArray(
233+
funbody = SetArray(
215234
!checkbounds,
216235
out,
217236
rhss
218237
)
219238

220-
s = structure(sys)
221-
states = map(i->s.fullvars[i], diffvars_range(s))
222239
syms = map(Symbol, states)
223240
pre = get_postprocess_fbody(sys)
224241

@@ -232,13 +249,13 @@ function build_torn_function(
232249
],
233250
[],
234251
pre(Let(
235-
collect(Iterators.flatten(get_torn_eqs_vars(sys, checkbounds=checkbounds))),
236-
odefunbody
252+
torn_expr,
253+
funbody
237254
))
238255
)
239256
)
240257
if expression
241-
expr
258+
expr, states
242259
else
243260
observedfun = let sys = sys, dict = Dict()
244261
function generated_observed(obsvar, u, p, t)
@@ -254,7 +271,8 @@ function build_torn_function(
254271
sparsity = torn_system_jacobian_sparsity(sys),
255272
syms = syms,
256273
observed = observedfun,
257-
)
274+
mass_matrix = mass_matrix,
275+
), states
258276
end
259277
end
260278

@@ -385,14 +403,12 @@ function ODAEProblem{iip}(
385403
parammap=DiffEqBase.NullParameters();
386404
kw...
387405
) where {iip}
388-
s = structure(sys)
389-
@unpack fullvars = s
390-
dvs = map(i->fullvars[i], diffvars_range(s))
406+
fun, dvs = build_torn_function(sys; kw...)
391407
ps = parameters(sys)
392408
defs = defaults(sys)
393409

394410
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs)
395411
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs)
396412

397-
ODEProblem{iip}(build_torn_function(sys; kw...), u0, tspan, p; kw...)
413+
ODEProblem{iip}(fun, u0, tspan, p; kw...)
398414
end

0 commit comments

Comments
 (0)