Skip to content

Commit 3a28e4c

Browse files
committed
More robust broadcast
1 parent 9a7c1e4 commit 3a28e4c

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

src/vector_of_array.jl

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,30 +149,24 @@ end
149149

150150
## broadcasting
151151

152-
struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153-
VectorOfArrayStyle(::Any) = VectorOfArrayStyle()
154-
VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle()
155-
156-
# promotion rules
157-
#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158-
# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159-
#end
160-
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle()
161-
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
162-
163-
function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S}
164-
VectorOfArrayStyle()
165-
end
152+
struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
153+
VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{N}()
154+
155+
# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
156+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle{M}) where {M,N} = Base.Broadcast.DefaultArrayStyle(Val(max(M, N)))
157+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{N}, a::Base.Broadcast.AbstractArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M, N)))
158+
Broadcast.BroadcastStyle(::VectorOfArrayStyle{M}, ::VectorOfArrayStyle{N}) where {M,N} = VectorOfArrayStyle(Val(max(M, N)))
159+
Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,N}}) where {T,N} = VectorOfArrayStyle{N}()
166160

167-
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle})
161+
@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
168162
N = narrays(bc)
169163
x = unpack_voa(bc, 1)
170-
VectorOfArray(map(1:N) do i
164+
return VectorOfArray(map(1:N) do i
171165
copy(unpack_voa(bc, i))
172166
end)
173167
end
174168

175-
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle})
169+
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
176170
N = narrays(bc)
177171
@inbounds for i in 1:N
178172
copyto!(dest[i], unpack_voa(bc, i))
@@ -205,7 +199,7 @@ _narrays(args::Tuple{}) = 0
205199

206200
# drop axes because it is easier to recompute
207201
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
208-
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
202+
@inline unpack_voa(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
209203
unpack_voa(x,::Any) = x
210204
unpack_voa(x::AbstractVectorOfArray, i) = x.u[i]
211205
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]

0 commit comments

Comments
 (0)