Skip to content

Commit aff87fa

Browse files
fix: handle immutable buffers in initialization
1 parent b69d79e commit aff87fa

File tree

2 files changed

+162
-71
lines changed

2 files changed

+162
-71
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,7 @@ function SciMLBase.remake_initialization_data(
513513
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
514514
end
515515
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
516-
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
517-
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata)
516+
return @set oldinitdata.initializeprob = initprob
518517
end
519518

520519
dvs = unknowns(sys)
@@ -627,7 +626,7 @@ function SciMLBase.late_binding_update_u0_p(
627626
if length(newu0) != length(prob.u0)
628627
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
629628
end
630-
meta.set_initial_unknowns!(newp, newu0)
629+
newp = meta.set_initial_unknowns!(newp, newu0)
631630
return newu0, newp
632631
end
633632

src/systems/problem_utils.jl

Lines changed: 160 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -491,32 +491,6 @@ function scalarize_varmap!(varmap::AbstractDict)
491491
return varmap
492492
end
493493

494-
struct GetUpdatedMTKParameters{G, S}
495-
# `getu` functor which gets parameters that are unknowns during initialization
496-
getpunknowns::G
497-
# `setu` functor which returns a modified MTKParameters using those parameters
498-
setpunknowns::S
499-
end
500-
501-
function (f::GetUpdatedMTKParameters)(prob, initializesol)
502-
p = parameter_values(prob)
503-
p === nothing && return nothing
504-
mtkp = copy(p)
505-
f.setpunknowns(mtkp, f.getpunknowns(initializesol))
506-
mtkp
507-
end
508-
509-
struct UpdateInitializeprob{G, S}
510-
# `getu` functor which gets all values from prob
511-
getvals::G
512-
# `setu` functor which updates initializeprob with values
513-
setvals::S
514-
end
515-
516-
function (f::UpdateInitializeprob)(initializeprob, prob)
517-
f.setvals(initializeprob, f.getvals(prob))
518-
end
519-
520494
function get_temporary_value(p, floatT = Float64)
521495
stype = symtype(unwrap(p))
522496
return if stype == Real
@@ -669,48 +643,89 @@ function concrete_getu(indp, syms::AbstractVector)
669643
return Base.Fix1(reduce, vcat) getu(indp, split_syms)
670644
end
671645

646+
"""
647+
$(TYPEDSIGNATURES)
648+
649+
Given a source system `srcsys` and destination system `dstsys`, return a function that
650+
takes a value provider of `srcsys` and a value provider of `dstsys` and returns the
651+
`MTKParameters` object of the latter with values from the former.
652+
653+
# Keyword Arguments
654+
- `initials`: Whether to include the `Initial` parameters of `dstsys` among the values
655+
to be transferred.
656+
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
657+
"""
658+
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
659+
initials = false, unwrap_initials = false, p_constructor = identity)
660+
# if we call `getu` on this (and it were able to handle empty tuples) we get the
661+
# fields of `MTKParameters` except caches.
662+
syms = reorder_parameters(
663+
dstsys, parameters(dstsys; initial_parameters = initials); flatten = false)
664+
# `dstsys` is an initialization system, do basically everything is a tunable
665+
# and tunables are a mix of different types in `srcsys`. No initials. Constants
666+
# are going to be constants in `srcsys`, as are `nonnumeric`.
667+
668+
# `syms[1]` is always the tunables because `srcsys` will have initials.
669+
tunable_syms = syms[1]
670+
tunable_getter = p_constructor concrete_getu(srcsys, tunable_syms)
671+
rest_getters = map(Base.tail(Base.tail(syms))) do buf
672+
if buf == ()
673+
return Returns(())
674+
else
675+
return Base.Fix1(broadcast, p_constructor) getu(srcsys, buf)
676+
end
677+
end
678+
initials_getter = if initials
679+
initsyms = Vector{Any}(syms[2])
680+
allsyms = Set(all_symbols(srcsys))
681+
if unwrap_initials
682+
for i in eachindex(initsyms)
683+
sym = initsyms[i]
684+
innersym = if operation(sym) === getindex
685+
sym, idxs... = arguments(sym)
686+
only(arguments(sym))[idxs...]
687+
else
688+
only(arguments(sym))
689+
end
690+
if innersym in allsyms
691+
initsyms[i] = innersym
692+
end
693+
end
694+
end
695+
p_constructor concrete_getu(srcsys, initsyms)
696+
else
697+
Returns(SizedVector{0, Float64}())
698+
end
699+
getters = (tunable_getter, initials_getter, rest_getters...)
700+
getter = let getters = getters
701+
function _getter(valp, initprob)
702+
oldcache = parameter_values(initprob).caches
703+
MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp),
704+
getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () :
705+
copy.(oldcache))
706+
end
707+
end
708+
709+
return getter
710+
end
711+
672712
"""
673713
$(TYPEDSIGNATURES)
674714
675715
Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `dstsys`
676716
with values from `srcsys`.
677717
"""
678718
function ReconstructInitializeprob(
679-
srcsys::AbstractSystem, dstsys::AbstractSystem)
719+
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity)
680720
@assert is_initializesystem(dstsys)
681-
ugetter = getu(srcsys, unknowns(dstsys))
721+
ugetter = u0_constructor getu(srcsys, unknowns(dstsys))
682722
if is_split(dstsys)
683-
# if we call `getu` on this (and it were able to handle empty tuples) we get the
684-
# fields of `MTKParameters` except caches.
685-
syms = reorder_parameters(dstsys, parameters(dstsys); flatten = false)
686-
# `dstsys` is an initialization system, do basically everything is a tunable
687-
# and tunables are a mix of different types in `srcsys`. No initials. Constants
688-
# are going to be constants in `srcsys`, as are `nonnumeric`.
689-
690-
# `syms[1]` is always the tunables because `srcsys` will have initials.
691-
tunable_syms = syms[1]
692-
tunable_getter = concrete_getu(srcsys, tunable_syms)
693-
rest_getters = map(Base.tail(Base.tail(syms))) do buf
694-
if buf == ()
695-
return Returns(())
696-
else
697-
return getu(srcsys, buf)
698-
end
699-
end
700-
getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...)
701-
pgetter = let getters = getters
702-
function _getter(valp, initprob)
703-
oldcache = parameter_values(initprob).caches
704-
MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp),
705-
getters[4](valp), getters[5](valp), oldcache isa Tuple{} ? () :
706-
copy.(oldcache))
707-
end
708-
end
723+
pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor)
709724
else
710725
syms = parameters(dstsys)
711-
pgetter = let inner = concrete_getu(srcsys, syms)
726+
pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor
712727
function _getter2(valp, initprob)
713-
inner(valp)
728+
p_constructor(inner(valp))
714729
end
715730
end
716731
end
@@ -763,6 +778,54 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
763778
return u0, newp
764779
end
765780

781+
"""
782+
$(TYPEDSIGNATURES)
783+
784+
Given `sys` and its corresponding initialization system `initsys`, return the
785+
`initializeprobpmap` function in `OverrideInitData` for the systems.
786+
"""
787+
function construct_initializeprobpmap(
788+
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity)
789+
@assert is_initializesystem(initsys)
790+
if is_split(sys)
791+
return let getter = get_mtkparameters_reconstructor(
792+
initsys, sys; initials = true, p_constructor)
793+
function initprobpmap_split(prob, initsol)
794+
getter(initsol, prob)
795+
end
796+
end
797+
else
798+
return let getter = getu(initsys, parameters(sys; initial_parameters = true)),
799+
p_constructor = p_constructor
800+
801+
function initprobpmap_nosplit(prob, initsol)
802+
return p_constructor(getter(initsol))
803+
end
804+
end
805+
end
806+
end
807+
808+
function get_scimlfn(valp)
809+
valp isa SciMLBase.AbstractSciMLFunction && return valp
810+
if hasmethod(symbolic_container, Tuple{typeof(valp)}) &&
811+
(sc = symbolic_container(valp)) !== valp
812+
return get_scimlfn(sc)
813+
end
814+
throw(ArgumentError("SciMLFunction not found. This should never happen."))
815+
end
816+
817+
"""
818+
$(TYPEDSIGNATURES)
819+
820+
A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires
821+
`is_update_oop = Val{true}` to be passed to `update_initializeprob!`.
822+
"""
823+
function update_initializeprob!(initprob, prob)
824+
p = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.getter(
825+
prob, initprob)
826+
return remake(initprob; p)
827+
end
828+
766829
"""
767830
$(TYPEDEF)
768831
@@ -804,8 +867,8 @@ struct InitializationMetadata{R <: ReconstructInitializeprob, GUU, SIU}
804867
"""
805868
get_updated_u0::GUU
806869
"""
807-
A function which takes the `u0` of the problem and sets
808-
`Initial.(unknowns(sys))`.
870+
A function which takes parameter object and `u0` of the problem and sets
871+
`Initial.(unknowns(sys))` in the former, returning the updated parameter object.
809872
"""
810873
set_initial_unknowns!::SIU
811874
end
@@ -856,6 +919,38 @@ function (guu::GetUpdatedU0)(prob, initprob)
856919
return buffer
857920
end
858921

922+
struct SetInitialUnknowns{S}
923+
setter!::S
924+
end
925+
926+
function SetInitialUnknowns(sys::AbstractSystem)
927+
return SetInitialUnknowns(setu(sys, Initial.(unknowns(sys))))
928+
end
929+
930+
function (siu::SetInitialUnknowns)(p::MTKParameters, u0)
931+
if ArrayInterface.ismutable(p.initials)
932+
siu.setter!(p, u0)
933+
else
934+
originalT = similar_type(p.initials)
935+
@set! p.initials = MVector{length(p.initials), eltype(p.initials)}(p.initials)
936+
siu.setter!(p, u0)
937+
@set! p.initials = originalT(p.initials)
938+
end
939+
return p
940+
end
941+
942+
function (siu::SetInitialUnknowns)(p::Vector, u0)
943+
if ArrayInterface.ismutable(p)
944+
siu.setter!(p, u0)
945+
else
946+
originalT = similar_type(p)
947+
p = MVector{length(p), eltype(p)}(p)
948+
siu.setter!(p, u0)
949+
p = originalT(p)
950+
end
951+
return p
952+
end
953+
859954
"""
860955
$(TYPEDSIGNATURES)
861956
@@ -913,8 +1008,9 @@ function maybe_build_initialization_problem(
9131008
end
9141009
meta = InitializationMetadata(
9151010
u0map, pmap, guesses, Vector{Equation}(initialization_eqs),
916-
use_scc, ReconstructInitializeprob(sys, initializeprob.f.sys),
917-
get_initial_unknowns, setp(sys, Initial.(unknowns(sys))))
1011+
use_scc, ReconstructInitializeprob(
1012+
sys, initializeprob.f.sys; u0_constructor, p_constructor),
1013+
get_initial_unknowns, SetInitialUnknowns(sys))
9181014

9191015
if is_time_dependent(sys)
9201016
all_init_syms = Set(all_symbols(initializeprob))
@@ -930,20 +1026,16 @@ function maybe_build_initialization_problem(
9301026
if initializeprobmap === nothing && isempty(punknowns)
9311027
initializeprobpmap = nothing
9321028
else
933-
allsyms = all_symbols(initializeprob)
934-
initdvs = filter(x -> any(isequal(x), allsyms), unknowns(sys))
935-
getpunknowns = getu(initializeprob, [punknowns; initdvs])
936-
setpunknowns = setp(sys, [punknowns; Initial.(initdvs)])
937-
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
1029+
initializeprobpmap = construct_initializeprobpmap(
1030+
sys, initializeprob.f.sys; p_constructor)
9381031
end
9391032

9401033
reqd_syms = parameter_symbols(initializeprob)
9411034
# we still want the `initialization_data` because it helps with `remake`
9421035
if initializeprobmap === nothing && initializeprobpmap === nothing
9431036
update_initializeprob! = nothing
9441037
else
945-
update_initializeprob! = UpdateInitializeprob(
946-
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
1038+
update_initializeprob! = ModelingToolkit.update_initializeprob!
9471039
end
9481040

9491041
for p in punknowns
@@ -967,7 +1059,7 @@ function maybe_build_initialization_problem(
9671059
return (;
9681060
initialization_data = SciMLBase.OverrideInitData(
9691061
initializeprob, update_initializeprob!, initializeprobmap,
970-
initializeprobpmap; metadata = meta))
1062+
initializeprobpmap; metadata = meta, is_update_oop = Val{true}))
9711063
end
9721064

9731065
"""

0 commit comments

Comments
 (0)