Skip to content

Commit 71bfd7a

Browse files
feat: pre-compute observed equations of previous SCCs
1 parent ba4710b commit 71bfd7a

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -556,18 +556,11 @@ end
556556
struct SCCNonlinearFunction{iip} end
557557

558558
function SCCNonlinearFunction{iip}(
559-
sys::NonlinearSystem, vscc, escc, cachesyms; eval_expression = false,
559+
sys::NonlinearSystem, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
560560
eval_module = @__MODULE__, kwargs...) where {iip}
561-
dvs = unknowns(sys)
562561
ps = parameters(sys)
563562
rps = reorder_parameters(sys, ps)
564-
eqs = equations(sys)
565-
obs = observed(sys)
566563

567-
_dvs = dvs[vscc]
568-
_eqs = eqs[escc]
569-
obsidxs = observed_equations_used_by(sys, _eqs)
570-
_obs = obs[obsidxs]
571564
obs_assignments = [eq.lhs eq.rhs for eq in _obs]
572565

573566
cmap, cs = get_cmap(sys)
@@ -617,24 +610,46 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
617610

618611
_, u0, p = process_SciMLProblem(
619612
EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...)
620-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(u0)))
621613

622-
subprobs = []
623614
explicitfuns = []
615+
nlfuns = []
616+
prevobsidxs = Int[]
617+
cachevars = []
618+
cacheexprs = []
624619
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
625-
oldvars = dvs[reduce(vcat, view(var_sccs, 1:(i - 1)); init = Int[])]
626-
if isempty(oldvars)
627-
push!(explicitfuns, (_...) -> nothing)
620+
# subset unknowns and equations
621+
_dvs = dvs[vscc]
622+
_eqs = eqs[escc]
623+
# get observed equations required by this SCC
624+
obsidxs = observed_equations_used_by(sys, _eqs)
625+
# the ones used by previous SCCs can be precomputed into the cache
626+
setdiff!(obsidxs, prevobsidxs)
627+
_obs = obs[obsidxs]
628+
629+
if isempty(cachevars)
630+
push!(explicitfuns, Returns(nothing))
628631
else
629632
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
630633
push!(explicitfuns,
631-
CacheWriter(sys, oldvars, solsyms; eval_expression, eval_module))
634+
CacheWriter(sys, cacheexprs, solsyms; eval_expression, eval_module))
635+
end
636+
f = SCCNonlinearFunction{iip}(
637+
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)
638+
push!(nlfuns, f)
639+
append!(cachevars, _dvs)
640+
append!(cacheexprs, _dvs)
641+
for i in obsidxs
642+
push!(cachevars, obs[i].lhs)
643+
push!(cacheexprs, obs[i].rhs)
632644
end
633-
prob = NonlinearProblem(
634-
SCCNonlinearFunction{iip}(
635-
sys, vscc, escc, (oldvars,); eval_expression, eval_module, kwargs...),
636-
u0[vscc],
637-
p)
645+
append!(prevobsidxs, obsidxs)
646+
end
647+
648+
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars)))
649+
650+
subprobs = []
651+
for (f, vscc) in zip(nlfuns, var_sccs)
652+
prob = NonlinearProblem(f, u0[vscc], p)
638653
push!(subprobs, prob)
639654
end
640655

0 commit comments

Comments
 (0)