@@ -630,13 +630,17 @@ with promoted types.
630630
631631$(TYPEDFIELDS)
632632"""
633- struct ReconstructInitializeprob{G }
633+ struct ReconstructInitializeprob{GP, GU }
634634 """
635635 A function which when given the original problem and initialization problem, returns
636636 the parameter object of the initialization problem with values copied from the
637637 original.
638638 """
639- getter:: G
639+ pgetter:: GP
640+ """
641+ Given the original problem, return the `u0` of the initialization problem.
642+ """
643+ ugetter:: GU
640644end
641645
642646"""
@@ -674,6 +678,7 @@ with values from `srcsys`.
674678function ReconstructInitializeprob (
675679 srcsys:: AbstractSystem , dstsys:: AbstractSystem )
676680 @assert is_initializesystem (dstsys)
681+ ugetter = getu (srcsys, unknowns (dstsys))
677682 if is_split (dstsys)
678683 # if we call `getu` on this (and it were able to handle empty tuples) we get the
679684 # fields of `MTKParameters` except caches.
@@ -693,7 +698,7 @@ function ReconstructInitializeprob(
693698 end
694699 end
695700 getters = (tunable_getter, Returns (SizedVector {0, Float64} ()), rest_getters... )
696- getter = let getters = getters
701+ pgetter = let getters = getters
697702 function _getter (valp, initprob)
698703 oldcache = parameter_values (initprob). caches
699704 MTKParameters (getters[1 ](valp), getters[2 ](valp), getters[3 ](valp),
@@ -703,13 +708,13 @@ function ReconstructInitializeprob(
703708 end
704709 else
705710 syms = parameters (dstsys)
706- getter = let inner = concrete_getu (srcsys, syms)
711+ pgetter = let inner = concrete_getu (srcsys, syms)
707712 function _getter2 (valp, initprob)
708713 inner (valp)
709714 end
710715 end
711716 end
712- return ReconstructInitializeprob (getter )
717+ return ReconstructInitializeprob (pgetter, ugetter )
713718end
714719
715720"""
@@ -719,7 +724,7 @@ Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`.
719724"""
720725function (rip:: ReconstructInitializeprob )(srcvalp, dstvalp)
721726 # copy parameters
722- newp = rip. getter (srcvalp, dstvalp)
727+ newp = rip. pgetter (srcvalp, dstvalp)
723728 # no `u0`, so no type-promotion
724729 if state_values (dstvalp) === nothing
725730 return nothing , newp
@@ -735,11 +740,10 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
735740 elseif ! isempty (newp)
736741 T = promote_type (eltype (newp), T)
737742 end
743+ u0 = rip. ugetter (srcvalp)
738744 # and the eltype of the destination u0
739- if T == eltype (state_values (dstvalp))
740- u0 = state_values (dstvalp)
741- elseif T != Union{}
742- u0 = T .(state_values (dstvalp))
745+ if T != eltype (u0) && T != Union{}
746+ u0 = T .(u0)
743747 end
744748 # apply the promotion to tunables portion
745749 buf, repack, alias = SciMLStructures. canonicalize (SciMLStructures. Tunable (), newp)
@@ -911,11 +915,13 @@ function maybe_build_initialization_problem(
911915 punknowns = [p
912916 for p in all_variable_symbols (initializeprob)
913917 if is_parameter (sys, p)]
914- if isempty (punknowns)
918+ if initializeprobmap === nothing && isempty (punknowns)
915919 initializeprobpmap = nothing
916920 else
917- getpunknowns = getu (initializeprob, punknowns)
918- setpunknowns = setp (sys, punknowns)
921+ allsyms = all_symbols (initializeprob)
922+ initdvs = filter (x -> any (isequal (x), allsyms), unknowns (sys))
923+ getpunknowns = getu (initializeprob, [punknowns; initdvs])
924+ setpunknowns = setp (sys, [punknowns; Initial .(initdvs)])
919925 initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
920926 end
921927
0 commit comments