Skip to content

Commit 5c9e148

Browse files
feat: support caching of different types of subexpressions in SCCNonlinearProblem
1 parent eff5907 commit 5c9e148

File tree

2 files changed

+91
-25
lines changed

2 files changed

+91
-25
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ import SCCNonlinearSolve
5454
using Reexport
5555
using RecursiveArrayTools
5656
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
57-
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
57+
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
58+
undef_blocks, blocks
5859
import CommonSolve
5960
import EnumX
6061

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 89 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
573573
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
574574
end
575575

576+
const TypeT = Union{DataType, UnionAll}
577+
576578
struct CacheWriter{F}
577579
fn::F
578580
end
579581

580582
function (cw::CacheWriter)(p, sols)
581-
cw.fn(p.caches[1], sols, p...)
583+
cw.fn(p.caches..., sols, p...)
582584
end
583585

584-
function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation};
586+
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
587+
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
585588
eval_expression = false, eval_module = @__MODULE__)
586589
ps = parameters(sys)
587590
rps = reorder_parameters(sys, ps)
588591
obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
589592
cmap, cs = get_cmap(sys)
590593
cmap_assigns = [eq.lhs ← eq.rhs for eq in cmap]
594+
595+
outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)]
596+
body = map(eachindex(buffer_types), buffer_types) do i, T
597+
Symbol(:tmp, i) ← SetArray(true, outsyms[i], get(exprs, T, []))
598+
end
591599
fn = Func(
592-
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
600+
[outsyms..., DestructuredArgs(DestructuredArgs.(solsyms)),
593601
DestructuredArgs.(rps)...],
594602
[],
595-
SetArray(true, :out, exprs)
603+
Let(body, :())
596604
) |> wrap_assignments(false, obs_assigns)[2] |>
597605
wrap_parameter_dependencies(sys, false)[2] |>
598-
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |>
606+
wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
599607
wrap_assignments(false, cmap_assigns)[2] |> toexpr
600608
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
601609
end
@@ -677,8 +685,16 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
677685

678686
explicitfuns = []
679687
nlfuns = []
680-
prevobsidxs = Int[]
681-
cachesize = 0
688+
prevobsidxs = BlockArray(undef_blocks, Vector{Int}, Int[])
689+
# Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
690+
# dict to maintain a consistent order of buffers across SCCs
691+
cachetypes = TypeT[]
692+
cachesizes = Int[]
693+
# explicitfun! related information for each SCC
694+
# We need to compute buffer sizes before doing any codegen
695+
scc_cachevars = Dict{TypeT, Vector{Any}}[]
696+
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
697+
scc_eqs = Vector{Equation}[]
682698
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
683699
# subset unknowns and equations
684700
_dvs = dvs[vscc]
@@ -690,6 +706,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
690706
_obs = obs[obsidxs]
691707

692708
# get all subexpressions in the RHS which we can precompute in the cache
709+
# precomputed subexpressions should not contain `banned_vars`
693710
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
694711
for var in banned_vars
695712
iscall(var) || continue
@@ -706,37 +723,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
706723
_eqs[i].rhs, banned_vars, state)
707724
end
708725

709-
# cached variables and their corresponding expressions
710-
cachevars = Any[obs[i].lhs for i in prevobsidxs]
711-
cacheexprs = Any[obs[i].lhs for i in prevobsidxs]
726+
# map from symtype to cached variables and their expressions
727+
cachevars = Dict{Union{DataType, UnionAll}, Vector{Any}}()
728+
cacheexprs = Dict{Union{DataType, UnionAll}, Vector{Any}}()
729+
# observed of previous SCCs are in the cache
730+
# NOTE: When we get proper CSE, we can substitute these
731+
# and then use `subexpressions_not_involving_vars!`
732+
for i in prevobsidxs
733+
T = symtype(obs[i].lhs)
734+
buf = get!(() -> Any[], cachevars, T)
735+
push!(buf, obs[i].lhs)
736+
737+
buf = get!(() -> Any[], cacheexprs, T)
738+
push!(buf, obs[i].lhs)
739+
end
740+
712741
for (k, v) in state
713-
push!(cachevars, unwrap(v))
714-
push!(cacheexprs, unwrap(k))
742+
k = unwrap(k)
743+
v = unwrap(v)
744+
T = symtype(k)
745+
buf = get!(() -> Any[], cachevars, T)
746+
push!(buf, v)
747+
buf = get!(() -> Any[], cacheexprs, T)
748+
push!(buf, k)
715749
end
716-
cachesize = max(cachesize, length(cachevars))
750+
751+
# update the sizes of cache buffers
752+
for (T, buf) in cachevars
753+
idx = findfirst(isequal(T), cachetypes)
754+
if idx === nothing
755+
push!(cachetypes, T)
756+
push!(cachesizes, 0)
757+
idx = lastindex(cachetypes)
758+
end
759+
cachesizes[idx] = max(cachesizes[idx], length(buf))
760+
end
761+
762+
push!(scc_cachevars, cachevars)
763+
push!(scc_cacheexprs, cacheexprs)
764+
push!(scc_eqs, _eqs)
765+
blockpush!(prevobsidxs, obsidxs)
766+
end
767+
768+
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
769+
_dvs = dvs[vscc]
770+
_eqs = scc_eqs[i]
771+
obsidxs = prevobsidxs[Block(i)]
772+
_prevobsidxs = reduce(vcat, blocks(prevobsidxs)[1:(i - 1)]; init = Int[])
773+
_obs = obs[obsidxs]
774+
cachevars = scc_cachevars[i]
775+
cacheexprs = scc_cacheexprs[i]
717776

718777
if isempty(cachevars)
719778
push!(explicitfuns, Returns(nothing))
720779
else
721780
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
722781
push!(explicitfuns,
723-
CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs];
782+
CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
724783
eval_expression, eval_module))
725784
end
785+
786+
cachebufsyms = Tuple(map(cachetypes) do T
787+
get(cachevars, T, [])
788+
end)
726789
f = SCCNonlinearFunction{iip}(
727-
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)
790+
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, kwargs...)
728791
push!(nlfuns, f)
729-
append!(cachevars, _dvs)
730-
append!(cacheexprs, _dvs)
731-
for i in obsidxs
732-
push!(cachevars, obs[i].lhs)
733-
push!(cacheexprs, obs[i].rhs)
734-
end
735-
append!(prevobsidxs, obsidxs)
736792
end
737793

738-
if cachesize != 0
739-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
794+
if !isempty(cachetypes)
795+
templates = map(cachetypes, cachesizes) do T, n
796+
# Real refers to `eltype(u0)`
797+
if T == Real
798+
T = eltype(u0)
799+
elseif T <: Array && eltype(T) == Real
800+
T = Array{eltype(u0), ndims(T)}
801+
end
802+
BufferTemplate(T, n)
803+
end
804+
p = rebuild_with_caches(p, templates...)
740805
end
741806

742807
subprobs = []

0 commit comments

Comments
 (0)