Skip to content

Commit 2d50f7e

Browse files
fix: make ReconstructInitializeprob type-stable
1 parent 934937b commit 2d50f7e

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

src/systems/problem_utils.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,138 @@ function build_operating_point!(sys::AbstractSystem,
620620
return op, missing_unknowns, missing_pars
621621
end
622622

623+
"""
624+
$(TYPEDEF)
625+
626+
A callable struct used to reconstruct the `u0` and `p` of the initialization problem
627+
with promoted types.
628+
629+
# Fields
630+
631+
$(TYPEDFIELDS)
632+
"""
633+
struct ReconstructInitializeprob{G}
634+
"""
635+
A function which when called on the original problem returns the parameter object of
636+
the initialization problem.
637+
"""
638+
getter::G
639+
end
640+
641+
"""
642+
$(TYPEDSIGNATURES)
643+
644+
Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
645+
function by splitting `syms` into contiguous buffers where the getter of each buffer
646+
is type-stable and constructing a function that calls and concatenates the results.
647+
"""
648+
function concrete_getu(indp, syms::AbstractVector)
649+
# a list of contiguous buffer
650+
split_syms = [Any[syms[1]]]
651+
# the type of the getter of the last buffer
652+
current = typeof(getu(indp, syms[1]))
653+
for sym in syms[2:end]
654+
getter = getu(indp, sym)
655+
if typeof(getter) != current
656+
# if types don't match, build a new buffer
657+
push!(split_syms, [])
658+
current = typeof(getter)
659+
end
660+
push!(split_syms[end], sym)
661+
end
662+
split_syms = Tuple(split_syms)
663+
# the getter is now type-stable, and we can vcat it to get the full buffer
664+
return Base.Fix1(reduce, vcat) getu(indp, split_syms)
665+
end
666+
667+
"""
668+
$(TYPEDSIGNATURES)
669+
670+
Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `dstsys`
671+
with values from `srcsys`.
672+
"""
673+
function ReconstructInitializeprob(
674+
srcsys::AbstractSystem, dstsys::AbstractSystem)
675+
@assert is_initializesystem(dstsys)
676+
if is_split(dstsys)
677+
# if we call `getu` on this (and it were able to handle empty tuples) we get the
678+
# fields of `MTKParameters` except caches.
679+
syms = reorder_parameters(dstsys, parameters(dstsys); flatten = false)
680+
# `dstsys` is an initialization system, do basically everything is a tunable
681+
# and tunables are a mix of different types in `srcsys`. No initials. Constants
682+
# are going to be constants in `srcsys`, as are `nonnumeric`.
683+
684+
# `syms[1]` is always the tunables because `srcsys` will have initials.
685+
tunable_syms = syms[1]
686+
tunable_getter = concrete_getu(srcsys, tunable_syms)
687+
rest_getters = map(Base.tail(Base.tail(syms))) do buf
688+
if buf == ()
689+
return Returns(())
690+
else
691+
return getu(srcsys, buf)
692+
end
693+
end
694+
getters = (tunable_getter, Returns(SizedVector{0, Float64}()), rest_getters...)
695+
getter = let getters = getters
696+
function _getter(valp)
697+
MTKParameters(getters[1](valp), getters[2](valp), getters[3](valp),
698+
getters[4](valp), getters[5](valp), ())
699+
end
700+
end
701+
else
702+
syms = parameters(dstsys)
703+
getter = concrete_getu(srcsys, syms)
704+
end
705+
return ReconstructInitializeprob(getter)
706+
end
707+
708+
"""
709+
$(TYPEDSIGNATURES)
710+
711+
Copy values from `srcvalp` to `dstvalp`. Returns the new `u0` and `p`.
712+
"""
713+
function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
714+
# copy parameters
715+
newp = rip.getter(srcvalp)
716+
# no `u0`, so no type-promotion
717+
if state_values(dstvalp) === nothing
718+
return nothing, newp
719+
end
720+
# the `eltype` of the `u0` of the source
721+
srcu0 = state_values(srcvalp)
722+
T = srcu0 === nothing ? Union{} : eltype(srcu0)
723+
# promote with the tunable eltype
724+
if parameter_values(dstvalp) isa MTKParameters
725+
if !isempty(newp.tunable)
726+
T = promote_type(eltype(newp.tunable), T)
727+
end
728+
elseif !isempty(newp)
729+
T = promote_type(eltype(newp), T)
730+
end
731+
# and the eltype of the destination u0
732+
if T == eltype(state_values(dstvalp))
733+
u0 = state_values(dstvalp)
734+
elseif T != Union{}
735+
u0 = T.(state_values(dstvalp))
736+
end
737+
# apply the promotion to tunables portion
738+
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
739+
if eltype(buf) != T
740+
# only do a copy if the eltype doesn't match
741+
newbuf = similar(buf, T)
742+
copyto!(newbuf, buf)
743+
newp = repack(newbuf)
744+
end
745+
# and initials portion
746+
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
747+
if eltype(buf) != T
748+
newbuf = similar(buf, T)
749+
copyto!(newbuf, buf)
750+
newp = repack(newbuf)
751+
end
752+
return u0, newp
753+
end
754+
623755
"""
624756
$(TYPEDEF)
625757

0 commit comments

Comments
 (0)