@@ -630,13 +630,17 @@ with promoted types.
630
630
631
631
$(TYPEDFIELDS)
632
632
"""
633
- struct ReconstructInitializeprob{G }
633
+ struct ReconstructInitializeprob{GP, GU }
634
634
"""
635
635
A function which when given the original problem and initialization problem, returns
636
636
the parameter object of the initialization problem with values copied from the
637
637
original.
638
638
"""
639
- getter:: G
639
+ pgetter:: GP
640
+ """
641
+ Given the original problem, return the `u0` of the initialization problem.
642
+ """
643
+ ugetter:: GU
640
644
end
641
645
642
646
"""
@@ -674,6 +678,7 @@ with values from `srcsys`.
674
678
function ReconstructInitializeprob (
675
679
srcsys:: AbstractSystem , dstsys:: AbstractSystem )
676
680
@assert is_initializesystem (dstsys)
681
+ ugetter = getu (srcsys, unknowns (dstsys))
677
682
if is_split (dstsys)
678
683
# if we call `getu` on this (and it were able to handle empty tuples) we get the
679
684
# fields of `MTKParameters` except caches.
@@ -693,7 +698,7 @@ function ReconstructInitializeprob(
693
698
end
694
699
end
695
700
getters = (tunable_getter, Returns (SizedVector {0, Float64} ()), rest_getters... )
696
- getter = let getters = getters
701
+ pgetter = let getters = getters
697
702
function _getter (valp, initprob)
698
703
oldcache = parameter_values (initprob). caches
699
704
MTKParameters (getters[1 ](valp), getters[2 ](valp), getters[3 ](valp),
@@ -703,13 +708,13 @@ function ReconstructInitializeprob(
703
708
end
704
709
else
705
710
syms = parameters (dstsys)
706
- getter = let inner = concrete_getu (srcsys, syms)
711
+ pgetter = let inner = concrete_getu (srcsys, syms)
707
712
function _getter2 (valp, initprob)
708
713
inner (valp)
709
714
end
710
715
end
711
716
end
712
- return ReconstructInitializeprob (getter )
717
+ return ReconstructInitializeprob (pgetter, ugetter )
713
718
end
714
719
715
720
"""
@@ -719,7 +724,7 @@ Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`.
719
724
"""
720
725
function (rip:: ReconstructInitializeprob )(srcvalp, dstvalp)
721
726
# copy parameters
722
- newp = rip. getter (srcvalp, dstvalp)
727
+ newp = rip. pgetter (srcvalp, dstvalp)
723
728
# no `u0`, so no type-promotion
724
729
if state_values (dstvalp) === nothing
725
730
return nothing , newp
@@ -735,11 +740,10 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
735
740
elseif ! isempty (newp)
736
741
T = promote_type (eltype (newp), T)
737
742
end
743
+ u0 = rip. ugetter (srcvalp)
738
744
# and the eltype of the destination u0
739
- if T == eltype (state_values (dstvalp))
740
- u0 = state_values (dstvalp)
741
- elseif T != Union{}
742
- u0 = T .(state_values (dstvalp))
745
+ if T != eltype (u0) && T != Union{}
746
+ u0 = T .(u0)
743
747
end
744
748
# apply the promotion to tunables portion
745
749
buf, repack, alias = SciMLStructures. canonicalize (SciMLStructures. Tunable (), newp)
@@ -911,11 +915,13 @@ function maybe_build_initialization_problem(
911
915
punknowns = [p
912
916
for p in all_variable_symbols (initializeprob)
913
917
if is_parameter (sys, p)]
914
- if isempty (punknowns)
918
+ if initializeprobmap === nothing && isempty (punknowns)
915
919
initializeprobpmap = nothing
916
920
else
917
- getpunknowns = getu (initializeprob, punknowns)
918
- setpunknowns = setp (sys, punknowns)
921
+ allsyms = all_symbols (initializeprob)
922
+ initdvs = filter (x -> any (isequal (x), allsyms), unknowns (sys))
923
+ getpunknowns = getu (initializeprob, [punknowns; initdvs])
924
+ setpunknowns = setp (sys, [punknowns; Initial .(initdvs)])
919
925
initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
920
926
end
921
927
0 commit comments