Skip to content

Commit 21325ca

Browse files
feat: store InitializationMetadata in OverrideInitData
1 parent 4a53fba commit 21325ca

File tree

2 files changed

+77
-35
lines changed

2 files changed

+77
-35
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -481,22 +481,19 @@ function SciMLBase.remake_initialization_data(
481481
if u0 === missing && p === missing
482482
return odefn.initialization_data
483483
end
484+
485+
oldinitdata = odefn.initialization_data
486+
484487
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
485-
oldinitdata = odefn.initialization_data
486488
oldinitdata === nothing && return nothing
487489

488490
oldinitprob = oldinitdata.initializeprob
489491
oldinitprob === nothing && return nothing
490-
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
491-
return oldinitdata
492-
end
493-
oldinitsys = oldinitprob.f.sys
494-
meta = get_metadata(oldinitsys)
495-
if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing
496-
reconstruct_fn = meta.oop_reconstruct_u0_p
497-
else
498-
reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys)
499-
end
492+
493+
meta = oldinitdata.metadata
494+
meta isa InitializationMetadata || return oldinitdata
495+
496+
reconstruct_fn = meta.oop_reconstruct_u0_p
500497
# the history function doesn't matter because `reconstruct_fn` is only going to
501498
# update the values of parameters, which aren't time dependent. The reason it
502499
# is called is because `Initial` parameters are calculated from the corresponding
@@ -507,16 +504,15 @@ function SciMLBase.remake_initialization_data(
507504
if oldinitprob.f.resid_prototype === nothing
508505
newf = oldinitprob.f
509506
else
510-
newf = NonlinearFunction{
511-
SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}(
512-
oldinitprob.f;
507+
newf = remake(oldinitprob.f;
513508
resid_prototype = calculate_resid_prototype(
514509
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
515510
end
516511
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
517512
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
518-
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap)
513+
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata)
519514
end
515+
520516
dvs = unknowns(sys)
521517
ps = parameters(sys)
522518
u0map = to_varmap(u0, dvs)
@@ -530,16 +526,13 @@ function SciMLBase.remake_initialization_data(
530526
use_scc = true
531527
initialization_eqs = Equation[]
532528

533-
if SciMLBase.has_initializeprob(odefn)
534-
oldsys = odefn.initialization_data.initializeprob.f.sys
535-
meta = get_metadata(oldsys)
536-
if meta isa InitializationSystemMetadata
537-
u0map = merge(meta.u0map, u0map)
538-
pmap = merge(meta.pmap, pmap)
539-
merge!(guesses, meta.additional_guesses)
540-
use_scc = get(meta.extra_metadata, :use_scc, true)
541-
initialization_eqs = meta.additional_initialization_eqs
542-
end
529+
if oldinitdata !== nothing && oldinitdata.metadata isa InitializationMetadata
530+
meta = oldinitdata.metadata
531+
u0map = merge(meta.u0map, u0map)
532+
pmap = merge(meta.pmap, pmap)
533+
merge!(guesses, meta.guesses)
534+
use_scc = meta.use_scc
535+
initialization_eqs = meta.additional_initialization_eqs
543536
else
544537
# there is no initializeprob, so the original problem construction
545538
# had no solvable parameters and had the differential variables
@@ -600,19 +593,22 @@ function SciMLBase.late_binding_update_u0_p(
600593
if !(eltype(u0) <: Pair)
601594
# if `p` is not provided or is symbolic
602595
p === missing || eltype(p) <: Pair || return newu0, newp
603-
newu0 === nothing && return newu0, newp
604-
all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp
596+
(newu0 === nothing || isempty(newu0)) && return newu0, newp
597+
initdata = prob.f.initialization_data
598+
initdata === nothing && return newu0, newp
599+
meta = initdata.metadata
600+
meta isa InitializationMetadata || return newu0, newp
605601
newp = p === missing ? copy(newp) : newp
606602
initials, repack, alias = SciMLStructures.canonicalize(
607603
SciMLStructures.Initials(), newp)
608604
if eltype(initials) != eltype(newu0)
609605
initials = DiffEqBase.promote_u0(initials, newu0, t0)
610606
newp = repack(initials)
611607
end
612-
if length(newu0) != length(unknowns(sys))
613-
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))"))
608+
if length(newu0) != length(prob.u0)
609+
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
614610
end
615-
setp(sys, Initial.(unknowns(sys)))(newp, newu0)
611+
meta.set_initial_unknowns!(newp, newu0)
616612
return newu0, newp
617613
end
618614

src/systems/problem_utils.jl

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,49 @@ function build_operating_point!(sys::AbstractSystem,
620620
return op, missing_unknowns, missing_pars
621621
end
622622

623+
"""
624+
$(TYPEDEF)
625+
626+
Metadata attached to `OverrideInitData` used in `remake` hooks for handling initialization
627+
properly.
628+
629+
# Fields
630+
631+
$(TYPEDFIELDS)
632+
"""
633+
struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
634+
"""
635+
The `u0map` used to construct the initialization.
636+
"""
637+
u0map::Dict{Any, Any}
638+
"""
639+
The `pmap` used to construct the initialization.
640+
"""
641+
pmap::Dict{Any, Any}
642+
"""
643+
The `guesses` used to construct the initialization.
644+
"""
645+
guesses::Dict{Any, Any}
646+
"""
647+
The `initialization_eqs` in addition to those of the system that were used to construct
648+
the initialization.
649+
"""
650+
additional_initialization_eqs::Vector{Equation}
651+
"""
652+
Whether to use `SCCNonlinearProblem` if possible.
653+
"""
654+
use_scc::Bool
655+
"""
656+
`ReconstructInitializeprob` for this initialization problem.
657+
"""
658+
oop_reconstruct_u0_p::R
659+
"""
660+
A function which takes the `u0` of the problem and sets
661+
`Initial.(unknowns(sys))`.
662+
"""
663+
set_initial_unknowns!::SIU
664+
end
665+
623666
"""
624667
$(TYPEDSIGNATURES)
625668
@@ -632,16 +675,16 @@ All other keyword arguments are forwarded to `InitializationProblem`.
632675
"""
633676
function maybe_build_initialization_problem(
634677
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
635-
guesses, missing_unknowns; implicit_dae = false,
636-
u0_constructor = identity, floatT = Float64, kwargs...)
678+
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity,
679+
floatT = Float64, initialization_eqs = [], use_scc = true, kwargs...)
637680
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
638681

639682
if t === nothing && is_time_dependent(sys)
640683
t = zero(floatT)
641684
end
642685

643686
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
644-
sys, t, u0map, pmap; guesses, kwargs...)
687+
sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, kwargs...)
645688
if state_values(initializeprob) !== nothing
646689
initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob)))
647690
end
@@ -658,7 +701,10 @@ function maybe_build_initialization_problem(
658701
end
659702
initializeprob = remake(initializeprob; p = initp)
660703

661-
meta = get_metadata(initializeprob.f.sys)
704+
meta = InitializationMetadata(
705+
u0map, pmap, guesses, Equation[get_initialization_eqs(sys); initialization_eqs],
706+
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
707+
setp(sys, Initial.(unknowns(sys))))
662708

663709
if is_time_dependent(sys)
664710
all_init_syms = Set(all_symbols(initializeprob))
@@ -710,7 +756,7 @@ function maybe_build_initialization_problem(
710756
return (;
711757
initialization_data = SciMLBase.OverrideInitData(
712758
initializeprob, update_initializeprob!, initializeprobmap,
713-
initializeprobpmap))
759+
initializeprobpmap; metadata = meta))
714760
end
715761

716762
"""

0 commit comments

Comments
 (0)