@@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
573573 NonlinearLeastSquaresProblem {iip} (f, u0, p; filter_kwargs (kwargs)... )
574574end
575575
576+ const TypeT = Union{DataType, UnionAll}
577+
576578struct CacheWriter{F}
577579 fn:: F
578580end
579581
580582function (cw:: CacheWriter )(p, sols)
581- cw. fn (p. caches[ 1 ] , sols, p... )
583+ cw. fn (p. caches... , sols, p... )
582584end
583585
584- function CacheWriter (sys:: AbstractSystem , exprs, solsyms, obseqs:: Vector{Equation} ;
586+ function CacheWriter (sys:: AbstractSystem , buffer_types:: Vector{TypeT} ,
587+ exprs:: Dict{TypeT, Vector{Any}} , solsyms, obseqs:: Vector{Equation} ;
585588 eval_expression = false , eval_module = @__MODULE__ )
586589 ps = parameters (sys)
587590 rps = reorder_parameters (sys, ps)
588591 obs_assigns = [eq. lhs ← eq. rhs for eq in obseqs]
589592 cmap, cs = get_cmap (sys)
590593 cmap_assigns = [eq. lhs ← eq. rhs for eq in cmap]
594+
595+ outsyms = [Symbol (:out , i) for i in eachindex (buffer_types)]
596+ body = map (eachindex (buffer_types), buffer_types) do i, T
597+ Symbol (:tmp , i) ← SetArray (true , outsyms[i], get (exprs, T, []))
598+ end
591599 fn = Func (
592- [:out , DestructuredArgs (DestructuredArgs .(solsyms)),
600+ [outsyms ... , DestructuredArgs (DestructuredArgs .(solsyms)),
593601 DestructuredArgs .(rps)... ],
594602 [],
595- SetArray ( true , :out , exprs )
603+ Let (body , :() )
596604 ) |> wrap_assignments (false , obs_assigns)[2 ] |>
597605 wrap_parameter_dependencies (sys, false )[2 ] |>
598- wrap_array_vars (sys, exprs ; dvs = nothing , inputs = [])[2 ] |>
606+ wrap_array_vars (sys, [] ; dvs = nothing , inputs = [])[2 ] |>
599607 wrap_assignments (false , cmap_assigns)[2 ] |> toexpr
600608 return CacheWriter (eval_or_rgf (fn; eval_expression, eval_module))
601609end
@@ -677,8 +685,16 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
677685
678686 explicitfuns = []
679687 nlfuns = []
680- prevobsidxs = Int[]
681- cachesize = 0
688+ prevobsidxs = BlockArray (undef_blocks, Vector{Int}, Int[])
689+ # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
690+ # dict to maintain a consistent order of buffers across SCCs
691+ cachetypes = TypeT[]
692+ cachesizes = Int[]
693+ # explicitfun! related information for each SCC
694+ # We need to compute buffer sizes before doing any codegen
695+ scc_cachevars = Dict{TypeT, Vector{Any}}[]
696+ scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
697+ scc_eqs = Vector{Equation}[]
682698 for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
683699 # subset unknowns and equations
684700 _dvs = dvs[vscc]
@@ -690,6 +706,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
690706 _obs = obs[obsidxs]
691707
692708 # get all subexpressions in the RHS which we can precompute in the cache
709+ # precomputed subexpressions should not contain `banned_vars`
693710 banned_vars = Set {Any} (vcat (_dvs, getproperty .(_obs, (:lhs ,))))
694711 for var in banned_vars
695712 iscall (var) || continue
@@ -706,37 +723,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
706723 _eqs[i]. rhs, banned_vars, state)
707724 end
708725
709- # cached variables and their corresponding expressions
710- cachevars = Any[obs[i]. lhs for i in prevobsidxs]
711- cacheexprs = Any[obs[i]. lhs for i in prevobsidxs]
726+ # map from symtype to cached variables and their expressions
727+ cachevars = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
728+ cacheexprs = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
729+ # observed of previous SCCs are in the cache
730+ # NOTE: When we get proper CSE, we can substitute these
731+ # and then use `subexpressions_not_involving_vars!`
732+ for i in prevobsidxs
733+ T = symtype (obs[i]. lhs)
734+ buf = get! (() -> Any[], cachevars, T)
735+ push! (buf, obs[i]. lhs)
736+
737+ buf = get! (() -> Any[], cacheexprs, T)
738+ push! (buf, obs[i]. lhs)
739+ end
740+
712741 for (k, v) in state
713- push! (cachevars, unwrap (v))
714- push! (cacheexprs, unwrap (k))
742+ k = unwrap (k)
743+ v = unwrap (v)
744+ T = symtype (k)
745+ buf = get! (() -> Any[], cachevars, T)
746+ push! (buf, v)
747+ buf = get! (() -> Any[], cacheexprs, T)
748+ push! (buf, k)
715749 end
716- cachesize = max (cachesize, length (cachevars))
750+
751+ # update the sizes of cache buffers
752+ for (T, buf) in cachevars
753+ idx = findfirst (isequal (T), cachetypes)
754+ if idx === nothing
755+ push! (cachetypes, T)
756+ push! (cachesizes, 0 )
757+ idx = lastindex (cachetypes)
758+ end
759+ cachesizes[idx] = max (cachesizes[idx], length (buf))
760+ end
761+
762+ push! (scc_cachevars, cachevars)
763+ push! (scc_cacheexprs, cacheexprs)
764+ push! (scc_eqs, _eqs)
765+ blockpush! (prevobsidxs, obsidxs)
766+ end
767+
768+ for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
769+ _dvs = dvs[vscc]
770+ _eqs = scc_eqs[i]
771+ obsidxs = prevobsidxs[Block (i)]
772+ _prevobsidxs = reduce (vcat, blocks (prevobsidxs)[1 : (i - 1 )]; init = Int[])
773+ _obs = obs[obsidxs]
774+ cachevars = scc_cachevars[i]
775+ cacheexprs = scc_cacheexprs[i]
717776
718777 if isempty (cachevars)
719778 push! (explicitfuns, Returns (nothing ))
720779 else
721780 solsyms = getindex .((dvs,), view (var_sccs, 1 : (i - 1 )))
722781 push! (explicitfuns,
723- CacheWriter (sys, cacheexprs, solsyms, obs[prevobsidxs ];
782+ CacheWriter (sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs ];
724783 eval_expression, eval_module))
725784 end
785+
786+ cachebufsyms = Tuple (map (cachetypes) do T
787+ get (cachevars, T, [])
788+ end )
726789 f = SCCNonlinearFunction {iip} (
727- sys, _eqs, _dvs, _obs, (cachevars,) ; eval_expression, eval_module, kwargs... )
790+ sys, _eqs, _dvs, _obs, cachebufsyms ; eval_expression, eval_module, kwargs... )
728791 push! (nlfuns, f)
729- append! (cachevars, _dvs)
730- append! (cacheexprs, _dvs)
731- for i in obsidxs
732- push! (cachevars, obs[i]. lhs)
733- push! (cacheexprs, obs[i]. rhs)
734- end
735- append! (prevobsidxs, obsidxs)
736792 end
737793
738- if cachesize != 0
739- p = rebuild_with_caches (p, BufferTemplate (eltype (u0), cachesize))
794+ if ! isempty (cachetypes)
795+ templates = map (cachetypes, cachesizes) do T, n
796+ # Real refers to `eltype(u0)`
797+ if T == Real
798+ T = eltype (u0)
799+ elseif T <: Array && eltype (T) == Real
800+ T = Array{eltype (u0), ndims (T)}
801+ end
802+ BufferTemplate (T, n)
803+ end
804+ p = rebuild_with_caches (p, templates... )
740805 end
741806
742807 subprobs = []
0 commit comments