@@ -147,27 +147,69 @@ end
147147 VA. t,VA. u
148148end
149149
150- # Broadcast
151-
152- # add_idxs(x,expr) = expr
153- # add_idxs{T<:AbstractVectorOfArray}(::Type{T},expr) = :($(expr)[i])
154- # add_idxs{T<:AbstractArray}(::Type{Vector{T}},expr) = :($(expr)[i])
155- #=
156- @generated function Base.broadcast!(f,A::AbstractVectorOfArray,B...)
157- exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
158- :(for i in eachindex(A)
159- broadcast!(f,A[i],$(exs...))
160- end)
161- end
162-
163- @generated function Base.broadcast(f,B::Union{Number,AbstractVectorOfArray}...)
164- arr_idx = 0
165- for (i,b) in enumerate(B)
166- if b <: ArrayPartition
167- arr_idx = i
168- break
150+ # # broadcasting
151+
152+ struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153+ VectorOfArrayStyle (:: Any ) = VectorOfArrayStyle ()
154+ VectorOfArrayStyle (:: Any , :: Any ) = VectorOfArrayStyle ()
155+
156+ # promotion rules
157+ # @inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158+ # VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159+ # end
160+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle , :: Broadcast.BroadcastStyle ) = VectorOfArrayStyle ()
161+ Broadcast. BroadcastStyle (:: VectorOfArrayStyle , :: Broadcast.DefaultArrayStyle{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
162+
163+ function Broadcast. BroadcastStyle (:: Type{<:AbstractVectorOfArray{T,S}} ) where {T, S}
164+ VectorOfArrayStyle ()
165+ end
166+
167+ @inline function Base. copy (bc:: Broadcast.Broadcasted{VectorOfArrayStyle} )
168+ N = narrays (bc)
169+ x = unpack_voa (bc, 1 )
170+ VectorOfArray (map (1 : N) do i
171+ copy (unpack_voa (bc, i))
172+ end )
173+ end
174+
175+ @inline function Base. copyto! (dest:: AbstractVectorOfArray , bc:: Broadcast.Broadcasted{VectorOfArrayStyle} )
176+ N = narrays (bc)
177+ @inbounds for i in 1 : N
178+ copyto! (dest[i], unpack_voa (bc, i))
169179 end
170- end
171- :(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
180+ dest
172181end
173- =#
182+
183+ # # broadcasting utils
184+
185+ """
186+ narrays(A...)
187+
188+ Retrieve number of arrays in the AbstractVectorOfArrays of a broadcast
189+ """
190+ narrays (A) = 0
191+ narrays (A:: AbstractVectorOfArray ) = length (A)
192+ narrays (bc:: Broadcast.Broadcasted ) = _narrays (bc. args)
193+ narrays (A, Bs... ) = common_length (narrays (A), _narrays (Bs))
194+
195+ common_length (a, b) =
196+ a == 0 ? b :
197+ (b == 0 ? a :
198+ (a == b ? a :
199+ throw (DimensionMismatch (" number of arrays must be equal" ))))
200+
201+ _narrays (args:: AbstractVectorOfArray ) = length (args)
202+ @inline _narrays (args:: Tuple ) = common_length (narrays (args[1 ]), _narrays (Base. tail (args)))
203+ _narrays (args:: Tuple{Any} ) = _narrays (args[1 ])
204+ _narrays (args:: Tuple{} ) = 0
205+
206+ # drop axes because it is easier to recompute
207+ @inline unpack_voa (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args_voa (i, bc. args))
208+ @inline unpack_voa (bc:: Broadcast.Broadcasted{VectorOfArrayStyle} , i) = Broadcast. Broadcasted (bc. f, unpack_args_voa (i, bc. args))
209+ unpack_voa (x,:: Any ) = x
210+ unpack_voa (x:: AbstractVectorOfArray , i) = x. u[i]
211+ unpack_voa (x:: AbstractArray{T,N} , i) where {T,N} = @view x[ntuple (x-> Colon (),N- 1 )... ,i]
212+
213+ @inline unpack_args_voa (i, args:: Tuple ) = (unpack_voa (args[1 ], i), unpack_args_voa (i, Base. tail (args))... )
214+ unpack_args_voa (i, args:: Tuple{Any} ) = (unpack_voa (args[1 ], i),)
215+ unpack_args_voa (:: Any , args:: Tuple{} ) = ()
0 commit comments