Skip to content

Commit df5afa9

Browse files
fix: propagate guesses for algebraic variables instead of Initial(x)
1 parent cc2f068 commit df5afa9

File tree

2 files changed

+57
-14
lines changed

2 files changed

+57
-14
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -653,20 +653,16 @@ end
653653
function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...)
654654
supports_initialization(sys) || return prob
655655
initdata = prob.f.initialization_data
656-
initdata === nothing && return prob
656+
initdata isa SciMLBase.OverrideInitData || return prob
657+
meta = initdata.metadata
658+
meta isa InitializationMetadata || return prob
659+
meta.get_updated_u0 === nothing && return prob
657660

658661
u0 = state_values(prob)
659662
u0 === nothing && return prob
660663

661664
p = parameter_values(prob)
662665
t0 = is_time_dependent(prob) ? current_time(prob) : nothing
663-
meta = initdata.metadata
664-
665-
getter = if meta isa InitializationMetadata
666-
meta.get_initial_unknowns
667-
else
668-
getu(sys, Initial.(unknowns(sys)))
669-
end
670666

671667
if p isa MTKParameters
672668
buffer = p.initials
@@ -681,7 +677,8 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw..
681677
else
682678
T = StaticArrays.similar_type(u0)
683679
end
684-
return remake(prob; u0 = T(getter(p)))
680+
681+
return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)))
685682
end
686683

687684
"""

src/systems/problem_utils.jl

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ properly.
769769
770770
$(TYPEDFIELDS)
771771
"""
772-
struct InitializationMetadata{R <: ReconstructInitializeprob, GIU, SIU}
772+
struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU}
773773
"""
774774
The `u0map` used to construct the initialization.
775775
"""
@@ -796,17 +796,58 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, GIU, SIU}
796796
"""
797797
oop_reconstruct_u0_p::R
798798
"""
799-
A function which takes the parameter object of the problem and returns
800-
`Initial.(unknowns(sys))`.
799+
A function which takes `(prob, initializeprob)` and return the `u0` to use for the problem.
801800
"""
802-
get_initial_unknowns::GIU
801+
get_updated_u0::GUU
803802
"""
804803
A function which takes the `u0` of the problem and sets
805804
`Initial.(unknowns(sys))`.
806805
"""
807806
set_initial_unknowns!::SIU
808807
end
809808

809+
"""
810+
$(TYPEDEF)
811+
812+
A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`.
813+
Returns the value of `Initial.(unknowns(sys))`, except with algebraic variables replaced
814+
by their guess values in the initialization problem.
815+
816+
# Fields
817+
818+
$(TYPEDFIELDS)
819+
"""
820+
struct GetUpdatedU0{GA, GIU}
821+
"""
822+
Mask with length `length(unknowns(sys))` denoting indices of algebraic variables.
823+
"""
824+
algevars::BitVector
825+
"""
826+
Function which returns the values of algebraic variables in `initializeprob`, in the
827+
order the algebraic variables occur in `unknowns(sys)`.
828+
"""
829+
get_algevars::GA
830+
"""
831+
Function which returns `Initial.(unknowns(sys))` as a `Vector`.
832+
"""
833+
get_initial_unknowns::GIU
834+
end
835+
836+
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem)
837+
algevaridxs = BitVector(is_alg_equation.(equations(sys)))
838+
algevars = unknowns(sys)[algevaridxs]
839+
get_algevars = getu(initsys, algevars)
840+
get_initial_unknowns = getu(sys, Initial.(unknowns(sys)))
841+
return GetUpdatedU0(algevaridxs, get_algevars, get_initial_unknowns)
842+
end
843+
844+
function (guu::GetUpdatedU0)(prob, initprob)
845+
buffer = guu.get_initial_unknowns(prob)
846+
algebuf = view(buffer, guu.algevars)
847+
copyto!(algebuf, guu.get_algevars(initprob))
848+
return buffer
849+
end
850+
810851
"""
811852
$(TYPEDSIGNATURES)
812853
@@ -845,10 +886,15 @@ function maybe_build_initialization_problem(
845886
end
846887
initializeprob = remake(initializeprob; p = initp)
847888

889+
get_initial_unknowns = if is_time_dependent(sys)
890+
GetUpdatedU0(sys, initializeprob.f.sys)
891+
else
892+
nothing
893+
end
848894
meta = InitializationMetadata(
849895
u0map, pmap, guesses, Vector{Equation}(initialization_eqs),
850896
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
851-
getp(sys, Initial.(unknowns(sys))), setp(sys, Initial.(unknowns(sys))))
897+
get_initial_unknowns, setp(sys, Initial.(unknowns(sys))))
852898

853899
if is_time_dependent(sys)
854900
all_init_syms = Set(all_symbols(initializeprob))

0 commit comments

Comments
 (0)