@@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
573573 NonlinearLeastSquaresProblem {iip} (f, u0, p; filter_kwargs (kwargs)... )
574574end
575575
# Kinds of type objects used to identify a cache buffer: both plain `DataType`s and
# `UnionAll`s are accepted. Used as the element type of the buffer-type list and as
# the key type of the per-type cache dictionaries.
const TypeT = Union{DataType, UnionAll}
577+
"""
    CacheWriter{F}

Callable wrapper around a function `fn` of type `F`. Calling a `CacheWriter` forwards
to `fn`; the wrapped function is produced by the `CacheWriter(sys, ...)` constructor
via code generation.

# Fields

- `fn`: the wrapped function.
"""
struct CacheWriter{F}
    fn::F
end
579581
"""
    (cw::CacheWriter)(p, sols)

Invoke the wrapped function as `cw.fn(p.caches, sols, p...)`: the whole `caches`
field of `p` is passed as the first argument, followed by `sols` and then the
splatted contents of `p`.
"""
(cw::CacheWriter)(p, sols) = cw.fn(p.caches, sols, p...)
583585
"""
    CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
        eval_expression = false, eval_module = @__MODULE__)

Generate and wrap a function that fills one output buffer per entry of `buffer_types`.

The generated function takes `out`, the destructured `solsyms`, and the destructured
reordered parameters of `sys` as arguments. For the `i`-th buffer type `T` it writes
the expressions `exprs[T]` into `out[i]`; a type with no entry in `exprs` gets an
empty list of expressions.

# Arguments

- `sys`: system providing the parameters, parameter dependencies and constants.
- `buffer_types`: ordered list of buffer types; position `i` corresponds to `out[i]`.
- `exprs`: map from buffer type to the expressions stored in that buffer.
- `solsyms`: collection of symbol collections, destructured as the second argument.
- `obseqs`: observed equations made available as assignments around the body.

# Keyword arguments

- `eval_expression`, `eval_module`: forwarded to `eval_or_rgf` when turning the
  generated expression into a callable.

# Returns

A `CacheWriter` wrapping the generated function.
"""
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
        eval_expression = false, eval_module = @__MODULE__)
    ps = parameters(sys)
    rps = reorder_parameters(sys, ps)
    obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
    # Only the constant map is needed here; the second return value is unused.
    cmap, _ = get_cmap(sys)
    cmap_assigns = [eq.lhs ← eq.rhs for eq in cmap]

    # One `SetArray` per buffer, targeting `out[i]`. Each is bound to a throwaway
    # `tmp<i>` name so that the writes can be sequenced as assignments of a `Let`.
    body = map(eachindex(buffer_types), buffer_types) do i, T
        Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, []))
    end
    fn = Func(
             [:out, DestructuredArgs(DestructuredArgs.(solsyms)),
                 DestructuredArgs.(rps)...],
             [],
             # The `Let` body is the empty-tuple expression `:()`; the useful work is
             # done by the assignments above.
             Let(body, :())
         ) |> wrap_assignments(false, obs_assigns)[2] |>
         wrap_parameter_dependencies(sys, false)[2] |>
         wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
         wrap_assignments(false, cmap_assigns)[2] |> toexpr
    return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
end
@@ -677,8 +685,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
677685
678686 explicitfuns = []
679687 nlfuns = []
680- prevobsidxs = Int[]
681- cachesize = 0
688+ prevobsidxs = BlockArray (undef_blocks, Vector{Int}, Int[])
689+ # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
690+ # dict to maintain a consistent order of buffers across SCCs
691+ cachetypes = TypeT[]
692+ cachesizes = Int[]
693+ # explicitfun! related information for each SCC
694+ # We need to compute buffer sizes before doing any codegen
695+ scc_cachevars = Dict{TypeT, Vector{Any}}[]
696+ scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
697+ scc_eqs = Vector{Equation}[]
698+ scc_obs = Vector{Equation}[]
682699 for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
683700 # subset unknowns and equations
684701 _dvs = dvs[vscc]
@@ -690,11 +707,10 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
690707 _obs = obs[obsidxs]
691708
692709 # get all subexpressions in the RHS which we can precompute in the cache
710+ # precomputed subexpressions should not contain `banned_vars`
693711 banned_vars = Set {Any} (vcat (_dvs, getproperty .(_obs, (:lhs ,))))
694- for var in banned_vars
695- iscall (var) || continue
696- operation (var) === getindex || continue
697- push! (banned_vars, arguments (var)[1 ])
712+ filter! (banned_vars) do var
713+ symbolic_type (var) != ArraySymbolic () || all (x -> var[i] in banned_vars, eachindex (var))
698714 end
699715 state = Dict ()
700716 for i in eachindex (_obs)
@@ -706,37 +722,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
706722 _eqs[i]. rhs, banned_vars, state)
707723 end
708724
709- # cached variables and their corresponding expressions
710- cachevars = Any[obs[i]. lhs for i in prevobsidxs]
711- cacheexprs = Any[obs[i]. lhs for i in prevobsidxs]
725+ # map from symtype to cached variables and their expressions
726+ cachevars = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
727+ cacheexprs = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
728+ # observed of previous SCCs are in the cache
729+ # NOTE: When we get proper CSE, we can substitute these
730+ # and then use `subexpressions_not_involving_vars!`
731+ for i in prevobsidxs
732+ T = symtype (obs[i]. lhs)
733+ buf = get! (() -> Any[], cachevars, T)
734+ push! (buf, obs[i]. lhs)
735+
736+ buf = get! (() -> Any[], cacheexprs, T)
737+ push! (buf, obs[i]. lhs)
738+ end
739+
712740 for (k, v) in state
713- push! (cachevars, unwrap (v))
714- push! (cacheexprs, unwrap (k))
741+ k = unwrap (k)
742+ v = unwrap (v)
743+ T = symtype (k)
744+ buf = get! (() -> Any[], cachevars, T)
745+ push! (buf, v)
746+ buf = get! (() -> Any[], cacheexprs, T)
747+ push! (buf, k)
715748 end
716- cachesize = max (cachesize, length (cachevars))
749+
750+ # update the sizes of cache buffers
751+ for (T, buf) in cachevars
752+ idx = findfirst (isequal (T), cachetypes)
753+ if idx === nothing
754+ push! (cachetypes, T)
755+ push! (cachesizes, 0 )
756+ idx = lastindex (cachetypes)
757+ end
758+ cachesizes[idx] = max (cachesizes[idx], length (buf))
759+ end
760+
761+ push! (scc_cachevars, cachevars)
762+ push! (scc_cacheexprs, cacheexprs)
763+ push! (scc_eqs, _eqs)
764+ push! (scc_obs, _obs)
765+ blockpush! (prevobsidxs, obsidxs)
766+ end
767+
768+ for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
769+ _dvs = dvs[vscc]
770+ _eqs = scc_eqs[i]
771+ _prevobsidxs = reduce (vcat, blocks (prevobsidxs)[1 : (i - 1 )]; init = Int[])
772+ _obs = scc_obs[i]
773+ cachevars = scc_cachevars[i]
774+ cacheexprs = scc_cacheexprs[i]
717775
718776 if isempty (cachevars)
719777 push! (explicitfuns, Returns (nothing ))
720778 else
721779 solsyms = getindex .((dvs,), view (var_sccs, 1 : (i - 1 )))
722780 push! (explicitfuns,
723- CacheWriter (sys, cacheexprs, solsyms, obs[prevobsidxs ];
781+ CacheWriter (sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs ];
724782 eval_expression, eval_module))
725783 end
784+
785+ cachebufsyms = Tuple (map (cachetypes) do T
786+ get (cachevars, T, [])
787+ end )
726788 f = SCCNonlinearFunction {iip} (
727- sys, _eqs, _dvs, _obs, (cachevars,) ; eval_expression, eval_module, kwargs... )
789+ sys, _eqs, _dvs, _obs, cachebufsyms ; eval_expression, eval_module, kwargs... )
728790 push! (nlfuns, f)
729- append! (cachevars, _dvs)
730- append! (cacheexprs, _dvs)
731- for i in obsidxs
732- push! (cachevars, obs[i]. lhs)
733- push! (cacheexprs, obs[i]. rhs)
734- end
735- append! (prevobsidxs, obsidxs)
736791 end
737792
738- if cachesize != 0
739- p = rebuild_with_caches (p, BufferTemplate (eltype (u0), cachesize))
793+ if ! isempty (cachetypes)
794+ templates = map (cachetypes, cachesizes) do T, n
795+ # Real refers to `eltype(u0)`
796+ if T == Real
797+ T = eltype (u0)
798+ elseif T <: Array && eltype (T) == Real
799+ T = Array{eltype (u0), ndims (T)}
800+ end
801+ BufferTemplate (T, n)
802+ end
803+ p = rebuild_with_caches (p, templates... )
740804 end
741805
742806 subprobs = []
0 commit comments