Skip to content

Commit a35ae7d

Browse files
fix: handle discretes properly in get_mtkparameters_reconstructor
1 parent 4631a1e commit a35ae7d

File tree

1 file changed

+32
-15
lines changed

1 file changed

+32
-15
lines changed

src/systems/problem_utils.jl

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -672,13 +672,6 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
672672
else
673673
p_constructor concrete_getu(srcsys, tunable_syms)
674674
end
675-
rest_getters = map(Base.tail(Base.tail(syms))) do buf
676-
if buf == ()
677-
return Returns(())
678-
else
679-
return Base.Fix1(broadcast, p_constructor) getu(srcsys, buf)
680-
end
681-
end
682675
initials_getter = if initials && !isempty(syms[2])
683676
initsyms = Vector{Any}(syms[2])
684677
allsyms = Set(all_symbols(srcsys))
@@ -700,7 +693,29 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
700693
else
701694
Returns(SizedVector{0, Float64}())
702695
end
703-
getters = (tunable_getter, initials_getter, rest_getters...)
696+
discs_getter = if isempty(syms[3])
697+
Returns(())
698+
else
699+
ic = get_index_cache(dstsys)
700+
blockarrsizes = Tuple(map(ic.discrete_buffer_sizes) do bufsizes
701+
map(x -> x.length, bufsizes)
702+
end)
703+
# discretes need to be blocked arrays
704+
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
705+
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
706+
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
707+
# tuple of `BlockedArray`s
708+
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
709+
Base.Fix1(broadcast, p_constructor) getu(srcsys, syms[3])
710+
end
711+
rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf
712+
if buf == ()
713+
return Returns(())
714+
else
715+
return Base.Fix1(broadcast, p_constructor) getu(srcsys, buf)
716+
end
717+
end
718+
getters = (tunable_getter, initials_getter, discs_getter, rest_getters...)
704719
getter = let getters = getters
705720
function _getter(valp, initprob)
706721
oldcache = parameter_values(initprob).caches
@@ -772,12 +787,14 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
772787
copyto!(newbuf, buf)
773788
newp = repack(newbuf)
774789
end
775-
# and initials portion
776-
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
777-
if eltype(buf) != T
778-
newbuf = similar(buf, T)
779-
copyto!(newbuf, buf)
780-
newp = repack(newbuf)
790+
if newp isa MTKParameters
791+
# and initials portion
792+
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
793+
if eltype(buf) != T
794+
newbuf = similar(buf, T)
795+
copyto!(newbuf, buf)
796+
newp = repack(newbuf)
797+
end
781798
end
782799
return u0, newp
783800
end
@@ -793,7 +810,7 @@ function construct_initializeprobpmap(
793810
@assert is_initializesystem(initsys)
794811
if is_split(sys)
795812
return let getter = get_mtkparameters_reconstructor(
796-
initsys, sys; initials = true, p_constructor)
813+
initsys, sys; initials = true, unwrap_initials = true, p_constructor)
797814
function initprobpmap_split(prob, initsol)
798815
getter(initsol, prob)
799816
end

0 commit comments

Comments
 (0)