Skip to content

Commit 6aa661d

Browse files
refactor: store getu function in InitializationMetadata
1 parent eb37a14 commit 6aa661d

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,17 +592,26 @@ end
592592
function SciMLBase.late_binding_update_u0_p(
593593
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
594594
supports_initialization(sys) || return newu0, newp
595+
596+
initdata = prob.f.initialization_data
597+
meta = initdata === nothing ? nothing : initdata.metadata
595598
# If the user passes `p` to `remake` but not `u0` and `u0` isn't empty,
596599
# and if the system supports initialization (so it has initial parameters),
597600
# and if the initialization solves for `u0`,
598601
# THEN copy the values of `Initial`s to `newu0`.
599602
if u0 === missing
600-
if newu0 !== nothing && p !== missing && supports_initialization(sys) && prob.f.initialization_data !== nothing && prob.f.initialization_data.initializeprobmap !== nothing
603+
if newu0 !== nothing && p !== missing && supports_initialization(sys) &&
604+
initdata !== nothing && initdata.initializeprobmap !== nothing
605+
getter = if meta isa InitializationMetadata
606+
meta.get_initial_unknowns
607+
else
608+
getu(sys, Initial.(unknowns(sys)))
609+
end
601610
if ArrayInterface.ismutable(newu0)
602-
copyto!(newu0, getu(sys, Initial.(unknowns(sys)))(newp))
611+
copyto!(newu0, getter(newp))
603612
else
604613
T = StaticArrays.similar_type(newu0)
605-
newu0 = T(getu(sys, Initial.(unknowns(sys)))(newp))
614+
newu0 = T(getter(newp))
606615
end
607616
end
608617
return newu0, newp
@@ -613,7 +622,6 @@ function SciMLBase.late_binding_update_u0_p(
613622
# if `p` is not provided or is symbolic
614623
p === missing || eltype(p) <: Pair || return newu0, newp
615624
(newu0 === nothing || isempty(newu0)) && return newu0, newp
616-
initdata = prob.f.initialization_data
617625
initdata === nothing && return newu0, newp
618626
meta = initdata.metadata
619627
meta isa InitializationMetadata || return newu0, newp

src/systems/problem_utils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ properly.
769769
770770
$(TYPEDFIELDS)
771771
"""
772-
struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
772+
struct InitializationMetadata{R <: ReconstructInitializeprob, GIU, SIU}
773773
"""
774774
The `u0map` used to construct the initialization.
775775
"""
@@ -796,6 +796,11 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
796796
"""
797797
oop_reconstruct_u0_p::R
798798
"""
799+
A function which takes the parameter object of the problem and returns
800+
`Initial.(unknowns(sys))`.
801+
"""
802+
get_initial_unknowns::GIU
803+
"""
799804
A function which takes the `u0` of the problem and sets
800805
`Initial.(unknowns(sys))`.
801806
"""
@@ -843,7 +848,7 @@ function maybe_build_initialization_problem(
843848
meta = InitializationMetadata(
844849
u0map, pmap, guesses, Vector{Equation}(initialization_eqs),
845850
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
846-
setp(sys, Initial.(unknowns(sys))))
851+
getp(sys, Initial.(unknowns(sys))), setp(sys, Initial.(unknowns(sys))))
847852

848853
if is_time_dependent(sys)
849854
all_init_syms = Set(all_symbols(initializeprob))

0 commit comments

Comments
 (0)