Skip to content

Commit 01dd661

Browse files
fix: fix unwrapping of views in MTKParameters reconstructor
1 parent b1818ba commit 01dd661

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

src/systems/problem_utils.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,32 @@ function concrete_getu(indp, syms::AbstractVector)
643643
return Base.Fix1(reduce, vcat) getu(indp, split_syms)
644644
end
645645

646+
"""
647+
$(TYPEDEF)
648+
649+
A callable struct which applies `p_constructor` to possibly nested arrays. It also
650+
ensures that views (including nested ones) are concretized.
651+
"""
652+
struct PConstructorApplicator{F}
653+
p_constructor::F
654+
end
655+
656+
function (pca::PConstructorApplicator)(x::AbstractArray)
657+
pca.p_constructor(x)
658+
end
659+
660+
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray)
661+
collect(x)
662+
end
663+
664+
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray})
665+
collect(pca.(x))
666+
end
667+
668+
function (pca::PConstructorApplicator)(x::AbstractArray{<:AbstractArray})
669+
pca.p_constructor(pca.(x))
670+
end
671+
646672
"""
647673
$(TYPEDSIGNATURES)
648674
@@ -657,6 +683,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
657683
"""
658684
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
659685
initials = false, unwrap_initials = false, p_constructor = identity)
686+
p_constructor = PConstructorApplicator(p_constructor)
660687
# if we call `getu` on this (and it were able to handle empty tuples) we get the
661688
# fields of `MTKParameters` except caches.
662689
syms = reorder_parameters(
@@ -698,15 +725,16 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
698725
else
699726
ic = get_index_cache(dstsys)
700727
blockarrsizes = Tuple(map(ic.discrete_buffer_sizes) do bufsizes
701-
map(x -> x.length, bufsizes)
728+
p_constructor(map(x -> x.length, bufsizes))
702729
end)
703730
# discretes need to be blocked arrays
704731
# the `getu` returns a tuple of arrays corresponding to `p.discretes`
705732
# `Base.Fix1(...)` applies `p_constructor` to each of the arrays in the tuple
706733
# `Base.Fix2(...)` does `BlockedArray.(tuple_of_arrs, blockarrsizes)` returning a
707734
# tuple of `BlockedArray`s
708735
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes)
709-
Base.Fix1(broadcast, p_constructor) getu(srcsys, syms[3])
736+
Base.Fix1(broadcast, p_constructor)
737+
getu(srcsys, syms[3])
710738
end
711739
rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf
712740
if buf == ()
@@ -1307,7 +1335,7 @@ function process_SciMLProblem(
13071335
if !(pType <: AbstractArray)
13081336
pType = Array
13091337
end
1310-
p = MTKParameters(sys, op; floatT = floatT, container_type = pType, p_constructor)
1338+
p = MTKParameters(sys, op; floatT = floatT, p_constructor)
13111339
else
13121340
p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType))
13131341
end

0 commit comments

Comments
 (0)