@@ -491,32 +491,6 @@ function scalarize_varmap!(varmap::AbstractDict)
491491 return varmap
492492end
493493
494- struct GetUpdatedMTKParameters{G, S}
495- # `getu` functor which gets parameters that are unknowns during initialization
496- getpunknowns:: G
497- # `setu` functor which returns a modified MTKParameters using those parameters
498- setpunknowns:: S
499- end
500-
501- function (f:: GetUpdatedMTKParameters )(prob, initializesol)
502- p = parameter_values (prob)
503- p === nothing && return nothing
504- mtkp = copy (p)
505- f. setpunknowns (mtkp, f. getpunknowns (initializesol))
506- mtkp
507- end
508-
509- struct UpdateInitializeprob{G, S}
510- # `getu` functor which gets all values from prob
511- getvals:: G
512- # `setu` functor which updates initializeprob with values
513- setvals:: S
514- end
515-
516- function (f:: UpdateInitializeprob )(initializeprob, prob)
517- f. setvals (initializeprob, f. getvals (prob))
518- end
519-
520494function get_temporary_value (p, floatT = Float64)
521495 stype = symtype (unwrap (p))
522496 return if stype == Real
@@ -669,48 +643,89 @@ function concrete_getu(indp, syms::AbstractVector)
669643 return Base. Fix1 (reduce, vcat) ∘ getu (indp, split_syms)
670644end
671645
646+ """
647+ $(TYPEDSIGNATURES)
648+
649+ Given a source system `srcsys` and destination system `dstsys`, return a function that
650+ takes a value provider of `srcsys` and a value provider of `dstsys` and returns the
651+ `MTKParameters` object of the latter with values from the former.
652+
653+ # Keyword Arguments
654+ - `initials`: Whether to include the `Initial` parameters of `dstsys` among the values
655+ to be transferred.
656+ - `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
657+ """
658+ function get_mtkparameters_reconstructor (srcsys:: AbstractSystem , dstsys:: AbstractSystem ;
659+ initials = false , unwrap_initials = false , p_constructor = identity)
660+ # if we call `getu` on this (and it were able to handle empty tuples) we get the
661+ # fields of `MTKParameters` except caches.
662+ syms = reorder_parameters (
663+ dstsys, parameters (dstsys; initial_parameters = initials); flatten = false )
664+ # `dstsys` is an initialization system, do basically everything is a tunable
665+ # and tunables are a mix of different types in `srcsys`. No initials. Constants
666+ # are going to be constants in `srcsys`, as are `nonnumeric`.
667+
668+ # `syms[1]` is always the tunables because `srcsys` will have initials.
669+ tunable_syms = syms[1 ]
670+ tunable_getter = p_constructor ∘ concrete_getu (srcsys, tunable_syms)
671+ rest_getters = map (Base. tail (Base. tail (syms))) do buf
672+ if buf == ()
673+ return Returns (())
674+ else
675+ return Base. Fix1 (broadcast, p_constructor) ∘ getu (srcsys, buf)
676+ end
677+ end
678+ initials_getter = if initials
679+ initsyms = Vector {Any} (syms[2 ])
680+ allsyms = Set (all_symbols (srcsys))
681+ if unwrap_initials
682+ for i in eachindex (initsyms)
683+ sym = initsyms[i]
684+ innersym = if operation (sym) === getindex
685+ sym, idxs... = arguments (sym)
686+ only (arguments (sym))[idxs... ]
687+ else
688+ only (arguments (sym))
689+ end
690+ if innersym in allsyms
691+ initsyms[i] = innersym
692+ end
693+ end
694+ end
695+ p_constructor ∘ concrete_getu (srcsys, initsyms)
696+ else
697+ Returns (SizedVector {0, Float64} ())
698+ end
699+ getters = (tunable_getter, initials_getter, rest_getters... )
700+ getter = let getters = getters
701+ function _getter (valp, initprob)
702+ oldcache = parameter_values (initprob). caches
703+ MTKParameters (getters[1 ](valp), getters[2 ](valp), getters[3 ](valp),
704+ getters[4 ](valp), getters[5 ](valp), oldcache isa Tuple{} ? () :
705+ copy .(oldcache))
706+ end
707+ end
708+
709+ return getter
710+ end
711+
672712"""
673713 $(TYPEDSIGNATURES)
674714
675715Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `dstsys`
676716with values from `srcsys`.
677717"""
678718function ReconstructInitializeprob (
679- srcsys:: AbstractSystem , dstsys:: AbstractSystem )
719+ srcsys:: AbstractSystem , dstsys:: AbstractSystem ; u0_constructor = identity, p_constructor = identity )
680720 @assert is_initializesystem (dstsys)
681- ugetter = getu (srcsys, unknowns (dstsys))
721+ ugetter = u0_constructor ∘ getu (srcsys, unknowns (dstsys))
682722 if is_split (dstsys)
683- # if we call `getu` on this (and it were able to handle empty tuples) we get the
684- # fields of `MTKParameters` except caches.
685- syms = reorder_parameters (dstsys, parameters (dstsys); flatten = false )
686- # `dstsys` is an initialization system, do basically everything is a tunable
687- # and tunables are a mix of different types in `srcsys`. No initials. Constants
688- # are going to be constants in `srcsys`, as are `nonnumeric`.
689-
690- # `syms[1]` is always the tunables because `srcsys` will have initials.
691- tunable_syms = syms[1 ]
692- tunable_getter = concrete_getu (srcsys, tunable_syms)
693- rest_getters = map (Base. tail (Base. tail (syms))) do buf
694- if buf == ()
695- return Returns (())
696- else
697- return getu (srcsys, buf)
698- end
699- end
700- getters = (tunable_getter, Returns (SizedVector {0, Float64} ()), rest_getters... )
701- pgetter = let getters = getters
702- function _getter (valp, initprob)
703- oldcache = parameter_values (initprob). caches
704- MTKParameters (getters[1 ](valp), getters[2 ](valp), getters[3 ](valp),
705- getters[4 ](valp), getters[5 ](valp), oldcache isa Tuple{} ? () :
706- copy .(oldcache))
707- end
708- end
723+ pgetter = get_mtkparameters_reconstructor (srcsys, dstsys; p_constructor)
709724 else
710725 syms = parameters (dstsys)
711- pgetter = let inner = concrete_getu (srcsys, syms)
726+ pgetter = let inner = concrete_getu (srcsys, syms), p_constructor = p_constructor
712727 function _getter2 (valp, initprob)
713- inner (valp)
728+ p_constructor ( inner (valp) )
714729 end
715730 end
716731 end
@@ -763,6 +778,54 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
763778 return u0, newp
764779end
765780
781+ """
782+ $(TYPEDSIGNATURES)
783+
784+ Given `sys` and its corresponding initialization system `initsys`, return the
785+ `initializeprobpmap` function in `OverrideInitData` for the systems.
786+ """
787+ function construct_initializeprobpmap (
788+ sys:: AbstractSystem , initsys:: AbstractSystem ; p_constructor = identity)
789+ @assert is_initializesystem (initsys)
790+ if is_split (sys)
791+ return let getter = get_mtkparameters_reconstructor (
792+ initsys, sys; initials = true , p_constructor)
793+ function initprobpmap_split (prob, initsol)
794+ getter (initsol, prob)
795+ end
796+ end
797+ else
798+ return let getter = getu (initsys, parameters (sys; initial_parameters = true )),
799+ p_constructor = p_constructor
800+
801+ function initprobpmap_nosplit (prob, initsol)
802+ return p_constructor (getter (initsol))
803+ end
804+ end
805+ end
806+ end
807+
808+ function get_scimlfn (valp)
809+ valp isa SciMLBase. AbstractSciMLFunction && return valp
810+ if hasmethod (symbolic_container, Tuple{typeof (valp)}) &&
811+ (sc = symbolic_container (valp)) != = valp
812+ return get_scimlfn (sc)
813+ end
814+ throw (ArgumentError (" SciMLFunction not found. This should never happen." ))
815+ end
816+
817+ """
818+ $(TYPEDSIGNATURES)
819+
820+ A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires
821+ `is_update_oop = Val{true}` to be passed to `update_initializeprob!`.
822+ """
823+ function update_initializeprob! (initprob, prob)
824+ p = get_scimlfn (prob). initialization_data. metadata. oop_reconstruct_u0_p. getter (
825+ prob, initprob)
826+ return remake (initprob; p)
827+ end
828+
766829"""
767830 $(TYPEDEF)
768831
@@ -804,8 +867,8 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU}
804867 """
805868 get_updated_u0:: GUU
806869 """
807- A function which takes the `u0` of the problem and sets
808- `Initial.(unknowns(sys))`.
870+ A function which takes parameter object and `u0` of the problem and sets
871+ `Initial.(unknowns(sys))` in the former, returning the updated parameter object .
809872 """
810873 set_initial_unknowns!:: SIU
811874end
@@ -856,6 +919,38 @@ function (guu::GetUpdatedU0)(prob, initprob)
856919 return buffer
857920end
858921
922+ struct SetInitialUnknowns{S}
923+ setter!:: S
924+ end
925+
926+ function SetInitialUnknowns (sys:: AbstractSystem )
927+ return SetInitialUnknowns (setu (sys, Initial .(unknowns (sys))))
928+ end
929+
930+ function (siu:: SetInitialUnknowns )(p:: MTKParameters , u0)
931+ if ArrayInterface. ismutable (p. initials)
932+ siu. setter! (p, u0)
933+ else
934+ originalT = similar_type (p. initials)
935+ @set! p. initials = MVector {length(p.initials), eltype(p.initials)} (p. initials)
936+ siu. setter! (p, u0)
937+ @set! p. initials = originalT (p. initials)
938+ end
939+ return p
940+ end
941+
942+ function (siu:: SetInitialUnknowns )(p:: Vector , u0)
943+ if ArrayInterface. ismutable (p)
944+ siu. setter! (p, u0)
945+ else
946+ originalT = similar_type (p)
947+ p = MVector {length(p), eltype(p)} (p)
948+ siu. setter! (p, u0)
949+ p = originalT (p)
950+ end
951+ return p
952+ end
953+
859954"""
860955 $(TYPEDSIGNATURES)
861956
@@ -913,8 +1008,9 @@ function maybe_build_initialization_problem(
9131008 end
9141009 meta = InitializationMetadata (
9151010 u0map, pmap, guesses, Vector {Equation} (initialization_eqs),
916- use_scc, ReconstructInitializeprob (sys, initializeprob. f. sys),
917- get_initial_unknowns, setp (sys, Initial .(unknowns (sys))))
1011+ use_scc, ReconstructInitializeprob (
1012+ sys, initializeprob. f. sys; u0_constructor, p_constructor),
1013+ get_initial_unknowns, SetInitialUnknowns (sys))
9181014
9191015 if is_time_dependent (sys)
9201016 all_init_syms = Set (all_symbols (initializeprob))
@@ -930,20 +1026,16 @@ function maybe_build_initialization_problem(
9301026 if initializeprobmap === nothing && isempty (punknowns)
9311027 initializeprobpmap = nothing
9321028 else
933- allsyms = all_symbols (initializeprob)
934- initdvs = filter (x -> any (isequal (x), allsyms), unknowns (sys))
935- getpunknowns = getu (initializeprob, [punknowns; initdvs])
936- setpunknowns = setp (sys, [punknowns; Initial .(initdvs)])
937- initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
1029+ initializeprobpmap = construct_initializeprobpmap (
1030+ sys, initializeprob. f. sys; p_constructor)
9381031 end
9391032
9401033 reqd_syms = parameter_symbols (initializeprob)
9411034 # we still want the `initialization_data` because it helps with `remake`
9421035 if initializeprobmap === nothing && initializeprobpmap === nothing
9431036 update_initializeprob! = nothing
9441037 else
945- update_initializeprob! = UpdateInitializeprob (
946- getu (sys, reqd_syms), setu (initializeprob, reqd_syms))
1038+ update_initializeprob! = ModelingToolkit. update_initializeprob!
9471039 end
9481040
9491041 for p in punknowns
@@ -967,7 +1059,7 @@ function maybe_build_initialization_problem(
9671059 return (;
9681060 initialization_data = SciMLBase. OverrideInitData (
9691061 initializeprob, update_initializeprob!, initializeprobmap,
970- initializeprobpmap; metadata = meta))
1062+ initializeprobpmap; metadata = meta, is_update_oop = Val{ true } ))
9711063end
9721064
9731065"""
0 commit comments