Skip to content

Commit c385c11

Browse files
refactor: store getu function in InitializationMetadata
1 parent 67ef30c commit c385c11

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,17 +588,26 @@ end
588588
function SciMLBase.late_binding_update_u0_p(
589589
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
590590
supports_initialization(sys) || return newu0, newp
591+
592+
initdata = prob.f.initialization_data
593+
meta = initdata === nothing ? nothing : initdata.metadata
591594
# If the user passes `p` to `remake` but not `u0` and `u0` isn't empty,
592595
# and if the system supports initialization (so it has initial parameters),
593596
# and if the initialization solves for `u0`,
594597
# THEN copy the values of `Initial`s to `newu0`.
595598
if u0 === missing
596-
if newu0 !== nothing && p !== missing && supports_initialization(sys) && prob.f.initialization_data !== nothing && prob.f.initialization_data.initializeprobmap !== nothing
599+
if newu0 !== nothing && p !== missing && supports_initialization(sys) &&
600+
initdata !== nothing && initdata.initializeprobmap !== nothing
601+
getter = if meta isa InitializationMetadata
602+
meta.get_initial_unknowns
603+
else
604+
getu(sys, Initial.(unknowns(sys)))
605+
end
597606
if ArrayInterface.ismutable(newu0)
598-
copyto!(newu0, getu(sys, Initial.(unknowns(sys)))(newp))
607+
copyto!(newu0, getter(newp))
599608
else
600609
T = StaticArrays.similar_type(newu0)
601-
newu0 = T(getu(sys, Initial.(unknowns(sys)))(newp))
610+
newu0 = T(getter(newp))
602611
end
603612
end
604613
return newu0, newp
@@ -609,7 +618,6 @@ function SciMLBase.late_binding_update_u0_p(
609618
# if `p` is not provided or is symbolic
610619
p === missing || eltype(p) <: Pair || return newu0, newp
611620
(newu0 === nothing || isempty(newu0)) && return newu0, newp
612-
initdata = prob.f.initialization_data
613621
initdata === nothing && return newu0, newp
614622
meta = initdata.metadata
615623
meta isa InitializationMetadata || return newu0, newp

src/systems/problem_utils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ properly.
769769
770770
$(TYPEDFIELDS)
771771
"""
772-
struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
772+
struct InitializationMetadata{R <: ReconstructInitializeprob, GIU, SIU}
773773
"""
774774
The `u0map` used to construct the initialization.
775775
"""
@@ -796,6 +796,11 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
796796
"""
797797
oop_reconstruct_u0_p::R
798798
"""
799+
A function which takes the parameter object of the problem and returns
800+
`Initial.(unknowns(sys))`.
801+
"""
802+
get_initial_unknowns::GIU
803+
"""
799804
A function which takes the `u0` of the problem and sets
800805
`Initial.(unknowns(sys))`.
801806
"""
@@ -843,7 +848,7 @@ function maybe_build_initialization_problem(
843848
meta = InitializationMetadata(
844849
u0map, pmap, guesses, Vector{Equation}(initialization_eqs),
845850
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
846-
setp(sys, Initial.(unknowns(sys))))
851+
getp(sys, Initial.(unknowns(sys))), setp(sys, Initial.(unknowns(sys))))
847852

848853
if is_time_dependent(sys)
849854
all_init_syms = Set(all_symbols(initializeprob))

0 commit comments

Comments
 (0)