Skip to content

Commit 6549aaa

Browse files
committed
add broadcast capabilities for multdim VoA
1 parent 2973fc6 commit 6549aaa

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

src/vector_of_array.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -840,12 +840,37 @@ end
840840
# make vectorofarrays broadcastable so they aren't collected
841841
Broadcast.broadcastable(x::AbstractVectorOfArray) = x
842842

843+
# recurse through broadcast arguments and return a parent array for
844+
# the first VoA or DiffEqArray in the bc arguments
845+
function find_VoA_parent(args)
846+
arg = Base.first(args)
847+
if arg isa AbstractDiffEqArray
848+
# if first(args) is a DiffEqArray, use the underlying
849+
# field `u` of DiffEqArray as a parent array.
850+
return arg.u
851+
elseif arg isa AbstractVectorOfArray
852+
return parent(arg)
853+
else
854+
return find_VoA_parent(Base.tail(args))
855+
end
856+
end
857+
843858
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
844859
bc = Broadcast.flatten(bc)
845-
N = narrays(bc)
846-
VectorOfArray(map(1:N) do i
847-
copy(unpack_voa(bc, i))
848-
end)
860+
861+
parent = find_VoA_parent(bc.args)
862+
863+
if parent isa AbstractVector
864+
# this is the default behavior in v3.15.0
865+
N = narrays(bc)
866+
return VectorOfArray(map(1:N) do i
867+
copy(unpack_voa(bc, i))
868+
end)
869+
else # if parent isa AbstractArray
870+
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
871+
copy(unpack_voa(bc, i))
872+
end)
873+
end
849874
end
850875

851876
for (type, N_expr) in [

0 commit comments

Comments
 (0)