@@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
573
573
NonlinearLeastSquaresProblem {iip} (f, u0, p; filter_kwargs (kwargs)... )
574
574
end
575
575
576
# Type used to identify a cache buffer. In this file it is the element type of
# `buffer_types`/`cachetypes` and the key type of the dictionaries mapping a
# symbolic type to the expressions stored in the corresponding buffer. Both
# plain types (`Real`, `Vector{Real}`) and unparametrised ones (`Vector`) qualify.
const TypeT = Union{DataType, UnionAll}
577
+
576
578
"""
    CacheWriter{F}

Callable wrapper around a function `fn`.

Calling `cw(p, sols)` invokes `fn` with every entry of `p.caches` splatted as the
leading arguments, followed by `sols`, followed by the splatted contents of `p`:

    cw.fn(p.caches..., sols, p...)

# Fields
- `fn::F`: the wrapped function. NOTE(review): presumably the generated function that
  fills the cache buffers in place (see the `CacheWriter(sys, ...)` constructor) —
  confirm against the callers.
"""
struct CacheWriter{F}
    fn::F
end

function (cw::CacheWriter)(p, sols)
    # All cache buffers are forwarded (one positional argument per buffer), not just
    # the first one, so `fn` must accept as many leading arguments as `p.caches` has
    # entries.
    return cw.fn(p.caches..., sols, p...)
end
583
585
584
"""
    CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
        eval_expression = false, eval_module = @__MODULE__)

Build the code for a function with one output argument per entry of `buffer_types`
and wrap the compiled result in a `CacheWriter`.

# Arguments
- `sys`: system whose parameters (via `parameters`/`reorder_parameters`) and
  constants (via `get_cmap`) are made available to the generated code.
- `buffer_types`: ordered list of buffer types. The `i`-th output argument of the
  generated function is named `out<i>` and corresponds to `buffer_types[i]`.
- `exprs`: map from buffer type to the expressions written into that buffer. A type
  missing from `exprs` gets an empty list of expressions.
- `solsyms`: collections of symbols, passed through nested `DestructuredArgs` as a
  single argument following the output buffers.
- `obseqs`: equations turned into `lhs ← rhs` assignments wrapped around the body.

# Keywords
- `eval_expression`, `eval_module`: forwarded to `eval_or_rgf`.

# Returns
A `CacheWriter` whose `fn` has the signature `(out1, ..., outN, sols, ps...)`.
"""
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
        exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
        eval_expression = false, eval_module = @__MODULE__)
    ps = parameters(sys)
    rps = reorder_parameters(sys, ps)
    obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
    # Only the constant-defining equations are needed here; the second return value
    # of `get_cmap` is unused.
    cmap, _ = get_cmap(sys)
    cmap_assigns = [eq.lhs ← eq.rhs for eq in cmap]

    # One output argument and one `SetArray` statement per buffer type, in the order
    # given by `buffer_types`, so the argument order matches the order in which the
    # caller supplies the buffers.
    outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)]
    body = map(eachindex(buffer_types), buffer_types) do i, T
        # `tmp<i>` only binds the result of the write so that it can live in a `Let`.
        Symbol(:tmp, i) ← SetArray(true, outsyms[i], get(exprs, T, Any[]))
    end
    fn = Func(
             [outsyms..., DestructuredArgs(DestructuredArgs.(solsyms)),
                 DestructuredArgs.(rps)...],
             [],
             Let(body, :())
         ) |> wrap_assignments(false, obs_assigns)[2] |>
         wrap_parameter_dependencies(sys, false)[2] |>
         wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
         wrap_assignments(false, cmap_assigns)[2] |> toexpr
    return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
end
@@ -677,8 +685,16 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
677
685
678
686
explicitfuns = []
679
687
nlfuns = []
680
- prevobsidxs = Int[]
681
- cachesize = 0
688
+ prevobsidxs = BlockArray (undef_blocks, Vector{Int}, Int[])
689
+ # Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
690
+ # dict to maintain a consistent order of buffers across SCCs
691
+ cachetypes = TypeT[]
692
+ cachesizes = Int[]
693
+ # explicitfun! related information for each SCC
694
+ # We need to compute buffer sizes before doing any codegen
695
+ scc_cachevars = Dict{TypeT, Vector{Any}}[]
696
+ scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
697
+ scc_eqs = Vector{Equation}[]
682
698
for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
683
699
# subset unknowns and equations
684
700
_dvs = dvs[vscc]
@@ -690,6 +706,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
690
706
_obs = obs[obsidxs]
691
707
692
708
# get all subexpressions in the RHS which we can precompute in the cache
709
+ # precomputed subexpressions should not contain `banned_vars`
693
710
banned_vars = Set {Any} (vcat (_dvs, getproperty .(_obs, (:lhs ,))))
694
711
for var in banned_vars
695
712
iscall (var) || continue
@@ -706,37 +723,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
706
723
_eqs[i]. rhs, banned_vars, state)
707
724
end
708
725
709
- # cached variables and their corresponding expressions
710
- cachevars = Any[obs[i]. lhs for i in prevobsidxs]
711
- cacheexprs = Any[obs[i]. lhs for i in prevobsidxs]
726
+ # map from symtype to cached variables and their expressions
727
+ cachevars = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
728
+ cacheexprs = Dict {Union{DataType, UnionAll}, Vector{Any}} ()
729
+ # observed of previous SCCs are in the cache
730
+ # NOTE: When we get proper CSE, we can substitute these
731
+ # and then use `subexpressions_not_involving_vars!`
732
+ for i in prevobsidxs
733
+ T = symtype (obs[i]. lhs)
734
+ buf = get! (() -> Any[], cachevars, T)
735
+ push! (buf, obs[i]. lhs)
736
+
737
+ buf = get! (() -> Any[], cacheexprs, T)
738
+ push! (buf, obs[i]. lhs)
739
+ end
740
+
712
741
for (k, v) in state
713
- push! (cachevars, unwrap (v))
714
- push! (cacheexprs, unwrap (k))
742
+ k = unwrap (k)
743
+ v = unwrap (v)
744
+ T = symtype (k)
745
+ buf = get! (() -> Any[], cachevars, T)
746
+ push! (buf, v)
747
+ buf = get! (() -> Any[], cacheexprs, T)
748
+ push! (buf, k)
715
749
end
716
- cachesize = max (cachesize, length (cachevars))
750
+
751
+ # update the sizes of cache buffers
752
+ for (T, buf) in cachevars
753
+ idx = findfirst (isequal (T), cachetypes)
754
+ if idx === nothing
755
+ push! (cachetypes, T)
756
+ push! (cachesizes, 0 )
757
+ idx = lastindex (cachetypes)
758
+ end
759
+ cachesizes[idx] = max (cachesizes[idx], length (buf))
760
+ end
761
+
762
+ push! (scc_cachevars, cachevars)
763
+ push! (scc_cacheexprs, cacheexprs)
764
+ push! (scc_eqs, _eqs)
765
+ blockpush! (prevobsidxs, obsidxs)
766
+ end
767
+
768
+ for (i, (escc, vscc)) in enumerate (zip (eq_sccs, var_sccs))
769
+ _dvs = dvs[vscc]
770
+ _eqs = scc_eqs[i]
771
+ obsidxs = prevobsidxs[Block (i)]
772
+ _prevobsidxs = reduce (vcat, blocks (prevobsidxs)[1 : (i - 1 )]; init = Int[])
773
+ _obs = obs[obsidxs]
774
+ cachevars = scc_cachevars[i]
775
+ cacheexprs = scc_cacheexprs[i]
717
776
718
777
if isempty (cachevars)
719
778
push! (explicitfuns, Returns (nothing ))
720
779
else
721
780
solsyms = getindex .((dvs,), view (var_sccs, 1 : (i - 1 )))
722
781
push! (explicitfuns,
723
- CacheWriter (sys, cacheexprs, solsyms, obs[prevobsidxs ];
782
+ CacheWriter (sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs ];
724
783
eval_expression, eval_module))
725
784
end
785
+
786
+ cachebufsyms = Tuple (map (cachetypes) do T
787
+ get (cachevars, T, [])
788
+ end )
726
789
f = SCCNonlinearFunction {iip} (
727
- sys, _eqs, _dvs, _obs, (cachevars,) ; eval_expression, eval_module, kwargs... )
790
+ sys, _eqs, _dvs, _obs, cachebufsyms ; eval_expression, eval_module, kwargs... )
728
791
push! (nlfuns, f)
729
- append! (cachevars, _dvs)
730
- append! (cacheexprs, _dvs)
731
- for i in obsidxs
732
- push! (cachevars, obs[i]. lhs)
733
- push! (cacheexprs, obs[i]. rhs)
734
- end
735
- append! (prevobsidxs, obsidxs)
736
792
end
737
793
738
- if cachesize != 0
739
- p = rebuild_with_caches (p, BufferTemplate (eltype (u0), cachesize))
794
+ if ! isempty (cachetypes)
795
+ templates = map (cachetypes, cachesizes) do T, n
796
+ # Real refers to `eltype(u0)`
797
+ if T == Real
798
+ T = eltype (u0)
799
+ elseif T <: Array && eltype (T) == Real
800
+ T = Array{eltype (u0), ndims (T)}
801
+ end
802
+ BufferTemplate (T, n)
803
+ end
804
+ p = rebuild_with_caches (p, templates... )
740
805
end
741
806
742
807
subprobs = []
0 commit comments