Skip to content

Commit 4a03482

Browse files
feat: store InitializationMetadata in OverrideInitData
1 parent 0046f63 commit 4a03482

File tree

2 files changed

+77
-34
lines changed

2 files changed

+77
-34
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -481,22 +481,19 @@ function SciMLBase.remake_initialization_data(
481481
if u0 === missing && p === missing
482482
return odefn.initialization_data
483483
end
484+
485+
oldinitdata = odefn.initialization_data
486+
484487
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
485-
oldinitdata = odefn.initialization_data
486488
oldinitdata === nothing && return nothing
487489

488490
oldinitprob = oldinitdata.initializeprob
489491
oldinitprob === nothing && return nothing
490-
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
491-
return oldinitdata
492-
end
493-
oldinitsys = oldinitprob.f.sys
494-
meta = get_metadata(oldinitsys)
495-
if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing
496-
reconstruct_fn = meta.oop_reconstruct_u0_p
497-
else
498-
reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys)
499-
end
492+
493+
meta = oldinitdata.metadata
494+
meta isa InitializationMetadata || return oldinitdata
495+
496+
reconstruct_fn = meta.oop_reconstruct_u0_p
500497
# the history function doesn't matter because `reconstruct_fn` is only going to
501498
# update the values of parameters, which aren't time dependent. The reason it
502499
# is called is because `Initial` parameters are calculated from the corresponding
@@ -507,16 +504,15 @@ function SciMLBase.remake_initialization_data(
507504
if oldinitprob.f.resid_prototype === nothing
508505
newf = oldinitprob.f
509506
else
510-
newf = NonlinearFunction{
511-
SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}(
512-
oldinitprob.f;
507+
newf = remake(oldinitprob.f;
513508
resid_prototype = calculate_resid_prototype(
514509
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
515510
end
516511
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
517512
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
518-
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap)
513+
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata)
519514
end
515+
520516
dvs = unknowns(sys)
521517
ps = parameters(sys)
522518
u0map = to_varmap(u0, dvs)
@@ -530,16 +526,13 @@ function SciMLBase.remake_initialization_data(
530526
use_scc = true
531527
initialization_eqs = Equation[]
532528

533-
if SciMLBase.has_initializeprob(odefn)
534-
oldsys = odefn.initialization_data.initializeprob.f.sys
535-
meta = get_metadata(oldsys)
536-
if meta isa InitializationSystemMetadata
537-
u0map = merge(meta.u0map, u0map)
538-
pmap = merge(meta.pmap, pmap)
539-
merge!(guesses, meta.additional_guesses)
540-
use_scc = get(meta.extra_metadata, :use_scc, true)
541-
initialization_eqs = meta.additional_initialization_eqs
542-
end
529+
if oldinitdata !== nothing && oldinitdata.metadata isa InitializationMetadata
530+
meta = oldinitdata.metadata
531+
u0map = merge(meta.u0map, u0map)
532+
pmap = merge(meta.pmap, pmap)
533+
merge!(guesses, meta.guesses)
534+
use_scc = meta.use_scc
535+
initialization_eqs = meta.additional_initialization_eqs
543536
else
544537
# there is no initializeprob, so the original problem construction
545538
# had no solvable parameters and had the differential variables
@@ -598,19 +591,22 @@ function SciMLBase.late_binding_update_u0_p(
598591
if !(eltype(u0) <: Pair)
599592
# if `p` is not provided or is symbolic
600593
p === missing || eltype(p) <: Pair || return newu0, newp
601-
newu0 === nothing && return newu0, newp
602-
all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp
594+
(newu0 === nothing || isempty(newu0)) && return newu0, newp
595+
initdata = prob.f.initialization_data
596+
initdata === nothing && return newu0, newp
597+
meta = initdata.metadata
598+
meta isa InitializationMetadata || return newu0, newp
603599
newp = p === missing ? copy(newp) : newp
604600
initials, repack, alias = SciMLStructures.canonicalize(
605601
SciMLStructures.Initials(), newp)
606602
if eltype(initials) != eltype(newu0)
607603
initials = DiffEqBase.promote_u0(initials, newu0, t0)
608604
newp = repack(initials)
609605
end
610-
if length(newu0) != length(unknowns(sys))
611-
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))"))
606+
if length(newu0) != length(prob.u0)
607+
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
612608
end
613-
setp(sys, Initial.(unknowns(sys)))(newp, newu0)
609+
meta.set_initial_unknowns!(newp, newu0)
614610
return newu0, newp
615611
end
616612

src/systems/problem_utils.jl

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,49 @@ function build_operating_point!(sys::AbstractSystem,
636636
return op, missing_unknowns, missing_pars
637637
end
638638

639+
"""
640+
$(TYPEDEF)
641+
642+
Metadata attached to `OverrideInitData` used in `remake` hooks for handling initialization
643+
properly.
644+
645+
# Fields
646+
647+
$(TYPEDFIELDS)
648+
"""
649+
struct InitializationMetadata{R <: ReconstructInitializeprob, SIU}
650+
"""
651+
The `u0map` used to construct the initialization.
652+
"""
653+
u0map::Dict{Any, Any}
654+
"""
655+
The `pmap` used to construct the initialization.
656+
"""
657+
pmap::Dict{Any, Any}
658+
"""
659+
The `guesses` used to construct the initialization.
660+
"""
661+
guesses::Dict{Any, Any}
662+
"""
663+
The `initialization_eqs` in addition to those of the system that were used to construct
664+
the initialization.
665+
"""
666+
additional_initialization_eqs::Vector{Equation}
667+
"""
668+
Whether to use `SCCNonlinearProblem` if possible.
669+
"""
670+
use_scc::Bool
671+
"""
672+
`ReconstructInitializeprob` for this initialization problem.
673+
"""
674+
oop_reconstruct_u0_p::R
675+
"""
676+
A function which takes the `u0` of the problem and sets
677+
`Initial.(unknowns(sys))`.
678+
"""
679+
set_initial_unknowns!::SIU
680+
end
681+
639682
"""
640683
$(TYPEDSIGNATURES)
641684
@@ -648,16 +691,20 @@ All other keyword arguments are forwarded to `InitializationProblem`.
648691
"""
649692
function maybe_build_initialization_problem(
650693
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
651-
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity, kwargs...)
694+
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity,
695+
initialization_eqs = [], use_scc = true, kwargs...)
652696
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
653697

654698
if t === nothing && is_time_dependent(sys)
655699
t = 0.0
656700
end
657701

658702
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
659-
sys, t, u0map, pmap; guesses, kwargs...)
660-
meta = get_metadata(initializeprob.f.sys)
703+
sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, kwargs...)
704+
meta = InitializationMetadata(
705+
u0map, pmap, guesses, Equation[get_initialization_eqs(sys); initialization_eqs],
706+
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
707+
setp(sys, Initial.(unknowns(sys))))
661708

662709
if is_time_dependent(sys)
663710
all_init_syms = Set(all_symbols(initializeprob))
@@ -709,7 +756,7 @@ function maybe_build_initialization_problem(
709756
return (;
710757
initialization_data = SciMLBase.OverrideInitData(
711758
initializeprob, update_initializeprob!, initializeprobmap,
712-
initializeprobpmap))
759+
initializeprobpmap; metadata = meta))
713760
end
714761

715762
"""

0 commit comments

Comments
 (0)