Skip to content

Commit ae505a3

Browse files
fix: make ReconstructInitializeprob type-stable
1 parent 125fb11 commit ae505a3

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

src/systems/problem_utils.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,138 @@ function build_operating_point!(sys::AbstractSystem,
636636
return op, missing_unknowns, missing_pars
637637
end
638638

639+
"""
640+
$(TYPEDEF)
641+
642+
A callable struct used to reconstruct the `u0` and `p` of the initialization problem
643+
with promoted types.
644+
645+
# Fields
646+
647+
$(TYPEDFIELDS)
648+
"""
649+
struct ReconstructInitializeprob{G}
650+
"""
651+
A function which when called on the original problem returns the parameter object of
652+
the initialization problem.
653+
"""
654+
getter::G
655+
end
656+
657+
"""
658+
$(TYPEDSIGNATURES)
659+
660+
Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
661+
function by splitting `syms` into contiguous buffers where the getter of each buffer
662+
is type-stable and constructing a function that calls and concatenates the results.
663+
"""
664+
function concrete_getu(indp, syms::AbstractVector)
665+
# a list of contiguous buffer
666+
split_syms = [Any[syms[1]]]
667+
# the type of the getter of the last buffer
668+
current = typeof(getu(indp, syms[1]))
669+
for sym in syms[2:end]
670+
getter = getu(indp, sym)
671+
if typeof(getter) != current
672+
# if types don't match, build a new buffer
673+
push!(split_syms, [])
674+
current = typeof(getter)
675+
end
676+
push!(split_syms[end], sym)
677+
end
678+
split_syms = Tuple(split_syms)
679+
# the getter is now type-stable, and we can vcat it to get the full buffer
680+
return Base.Fix1(reduce, vcat) getu(indp, split_syms)
681+
end
682+
683+
"""
684+
$(TYPEDSIGNATURES)
685+
686+
Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `dstsys`
687+
with values from `srcsys`.
688+
"""
689+
function ReconstructInitializeprob(
690+
srcsys::AbstractSystem, dstsys::AbstractSystem)
691+
@assert is_initializesystem(dstsys)
692+
if is_split(dstsys)
693+
# if we call `getu` on this (and it were able to handle empty tuples) we get the
694+
# fields of `MTKParameters` except caches.
695+
syms = reorder_parameters(dstsys, parameters(dstsys); flatten = false)
696+
# `dstsys` is an initialization system, do basically everything is a tunable
697+
# and tunables are a mix of different types in `srcsys`. No initials. Constants
698+
# are going to be constants in `srcsys`, as are `nonnumeric`.
699+
700+
# `syms[1]` is always the tunables because `srcsys` will have initials.
701+
tunable_syms = syms[1]
702+
tunable_getter = concrete_getu(srcsys, tunable_syms)
703+
rest_getters = map(Base.tail(Base.tail(syms))) do buf
704+
if buf == ()
705+
return Returns(())
706+
else
707+
return getu(srcsys, buf)
708+
end
709+
end
710+
getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...)
711+
getter = let getters = getters
712+
function _getter(valp)
713+
MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp),
714+
getters[4](valp), getters[5](valp), ())
715+
end
716+
end
717+
else
718+
syms = parameters(dstsys)
719+
getter = concrete_getu(srcsys, syms)
720+
end
721+
return ReconstructInitializeprob(getter)
722+
end
723+
724+
"""
725+
$(TYPEDSIGNATURES)
726+
727+
Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`.
728+
"""
729+
function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
730+
# copy parameters
731+
newp = rip.getter(srcvalp)
732+
# no `u0`, so no type-promotion
733+
if state_values(dstvalp) === nothing
734+
return nothing, newp
735+
end
736+
# the `eltype` of the `u0` of the source
737+
srcu0 = state_values(srcvalp)
738+
T = srcu0 === nothing ? Union{} : eltype(srcu0)
739+
# promote with the tunable eltype
740+
if parameter_values(dstvalp) isa MTKParameters
741+
if !isempty(newp.tunable)
742+
T = promote_type(eltype(newp.tunable), T)
743+
end
744+
elseif !isempty(newp)
745+
T = promote_type(eltype(newp), T)
746+
end
747+
# and the eltype of the destination u0
748+
if T == eltype(state_values(dstvalp))
749+
u0 = state_values(dstvalp)
750+
elseif T != Union{}
751+
u0 = T.(state_values(dstvalp))
752+
end
753+
# apply the promotion to tunables portion
754+
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
755+
if eltype(buf) != T
756+
# only do a copy if the eltype doesn't match
757+
newbuf = similar(buf, T)
758+
copyto!(newbuf, buf)
759+
newp = repack(newbuf)
760+
end
761+
# and initials portion
762+
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
763+
if eltype(buf) != T
764+
newbuf = similar(buf, T)
765+
copyto!(newbuf, buf)
766+
newp = repack(newbuf)
767+
end
768+
return u0, newp
769+
end
770+
639771
"""
640772
$(TYPEDEF)
641773

0 commit comments

Comments
 (0)