@@ -149,30 +149,24 @@ end
149149
150150# # broadcasting
151151
152- struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153- VectorOfArrayStyle (:: Any ) = VectorOfArrayStyle ()
154- VectorOfArrayStyle (:: Any , :: Any ) = VectorOfArrayStyle ()
155-
156- # promotion rules
157- # @inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158- # VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159- # end
160- Broadcast. BroadcastStyle (:: VectorOfArrayStyle , :: Broadcast.BroadcastStyle ) = VectorOfArrayStyle ()
161- Broadcast. BroadcastStyle (:: VectorOfArrayStyle , :: Broadcast.DefaultArrayStyle{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
162-
163- function Broadcast. BroadcastStyle (:: Type{<:AbstractVectorOfArray{T,S}} ) where {T, S}
164- VectorOfArrayStyle ()
165- end
152+ struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
153+ VectorOfArrayStyle (:: Val{N} ) where N = VectorOfArrayStyle {N} ()
154+
155+ # The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
156+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle{N} , a:: Base.Broadcast.DefaultArrayStyle{M} ) where {M,N} = Base. Broadcast. DefaultArrayStyle (Val (max (M, N)))
157+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle{N} , a:: Base.Broadcast.AbstractArrayStyle{M} ) where {M,N} = typeof (a)(Val (max (M, N)))
158+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle{M} , :: VectorOfArrayStyle{N} ) where {M,N} = VectorOfArrayStyle (Val (max (M, N)))
159+ Broadcast. BroadcastStyle (:: Type{<:AbstractVectorOfArray{T,N}} ) where {T,N} = VectorOfArrayStyle {N} ()
166160
167- @inline function Base. copy (bc:: Broadcast.Broadcasted{VectorOfArrayStyle} )
161+ @inline function Base. copy (bc:: Broadcast.Broadcasted{<: VectorOfArrayStyle} )
168162 N = narrays (bc)
169163 x = unpack_voa (bc, 1 )
170- VectorOfArray (map (1 : N) do i
164+ return VectorOfArray (map (1 : N) do i
171165 copy (unpack_voa (bc, i))
172166 end )
173167end
174168
175- @inline function Base. copyto! (dest:: AbstractVectorOfArray , bc:: Broadcast.Broadcasted{VectorOfArrayStyle} )
169+ @inline function Base. copyto! (dest:: AbstractVectorOfArray , bc:: Broadcast.Broadcasted{<: VectorOfArrayStyle} )
176170 N = narrays (bc)
177171 @inbounds for i in 1 : N
178172 copyto! (dest[i], unpack_voa (bc, i))
@@ -205,7 +199,7 @@ _narrays(args::Tuple{}) = 0
205199
206200# drop axes because it is easier to recompute
207201@inline unpack_voa (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args_voa (i, bc. args))
208- @inline unpack_voa (bc:: Broadcast.Broadcasted{VectorOfArrayStyle} , i) = Broadcast. Broadcasted (bc. f, unpack_args_voa (i, bc. args))
202+ @inline unpack_voa (bc:: Broadcast.Broadcasted{<: VectorOfArrayStyle} , i) = Broadcast. Broadcasted (bc. f, unpack_args_voa (i, bc. args))
209203unpack_voa (x,:: Any ) = x
210204unpack_voa (x:: AbstractVectorOfArray , i) = x. u[i]
211205unpack_voa (x:: AbstractArray{T,N} , i) where {T,N} = @view x[ntuple (x-> Colon (),N- 1 )... ,i]
0 commit comments