@@ -3,7 +3,7 @@ struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T}
33end
44
55# # constructors
6-
6+ @inline ArrayPartition (f :: F , N) where F <: Function = ArrayPartition ( ntuple (f, Val (N)))
77ArrayPartition (x... ) = ArrayPartition ((x... ,))
88
99function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
@@ -23,26 +23,25 @@ Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(
2323Base. similar (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = similar (A)
2424
2525# similar array partition of common type
26- @generated function Base. similar (A:: ArrayPartition , :: Type{T} ) where {T}
26+ @inline function Base. similar (A:: ArrayPartition , :: Type{T} ) where {T}
2727 N = npartitions (A)
28- expr = :(similar (A. x[i], T))
29-
30- build_arraypartition (N, expr)
28+ ArrayPartition (i-> similar (A. x[i], T), N)
3129end
3230
3331# ignore dims since array partitions are vectors
3432Base. similar (A:: ArrayPartition , :: Type{T} , dims:: NTuple{N,Int} ) where {T,N} = similar (A, T)
3533
3634# similar array partition with different types
37- @generated function Base. similar (A:: ArrayPartition , :: Type{T} , :: Type{S} ,
38- R:: Vararg{Type} ) where {T,S}
35+ function Base. similar (A:: ArrayPartition , :: Type{T} , :: Type{S} , R:: DataType... ) where {T, S}
3936 N = npartitions (A)
4037 N != length (R) + 2 &&
4138 throw (DimensionMismatch (" number of types must be equal to number of partitions" ))
4239
43- types = (T, S, parameter .(R)... ) # new types
44- expr = :(similar (A. x[i], ($ types)[i]))
45- build_arraypartition (N, expr)
40+ types = (T, S, R... ) # new types
41+ @inline function f (i)
42+ similar (A. x[i], types[i])
43+ end
44+ ArrayPartition (f, N)
4645end
4746
4847Base. copy (A:: ArrayPartition{T,S} ) where {T,S} = ArrayPartition {T,S} (copy .(A. x))
@@ -52,17 +51,16 @@ Base.zero(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zero.(A.x))
5251# ignore dims since array partitions are vectors
5352Base. zero (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = zero (A)
5453
55-
56-
5754# # ones
5855
5956# special to work with units
60- @generated function Base. ones (A:: ArrayPartition )
57+ function Base. ones (A:: ArrayPartition )
6158 N = npartitions (A)
62-
63- expr = :(fill! (similar (A. x[i]), oneunit (eltype (A. x[i]))))
64-
65- build_arraypartition (N, expr)
59+ out = similar (A)
60+ for i in 1 : N
61+ fill! (out. x[i], oneunit (eltype (out. x[i])))
62+ end
63+ out
6664end
6765
6866# ignore dims since array partitions are vectors
@@ -72,50 +70,32 @@ Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A)
7270
7371for op in (:+ , :- )
7472 @eval begin
75- @generated function Base. $op (A:: ArrayPartition , B:: ArrayPartition )
76- N = npartitions (A, B)
77- expr = :($ ($ op). (A. x[i], B. x[i]))
78-
79- build_arraypartition (N, expr)
73+ function Base. $op (A:: ArrayPartition , B:: ArrayPartition )
74+ Base. broadcast ($ op, A, B)
8075 end
8176
82- @generated function Base. $op (A:: ArrayPartition , B:: Number )
83- N = npartitions (A)
84- expr = :($ ($ op). (A. x[i], B))
85-
86- build_arraypartition (N, expr)
77+ function Base. $op (A:: ArrayPartition , B:: Number )
78+ Base. broadcast ($ op, A, B)
8779 end
8880
89- @generated function Base. $op (A:: Number , B:: ArrayPartition )
90- N = npartitions (B)
91- expr = :($ ($ op). (A, B. x[i]))
92-
93- build_arraypartition (N, expr)
81+ function Base. $op (A:: Number , B:: ArrayPartition )
82+ Base. broadcast ($ op, A, B)
9483 end
9584 end
9685end
9786
9887for op in (:* , :/ )
99- @eval @generated function Base. $op (A:: ArrayPartition , B:: Number )
100- N = npartitions (A)
101- expr = :($ ($ op). (A. x[i], B))
102-
103- build_arraypartition (N, expr)
88+ @eval function Base. $op (A:: ArrayPartition , B:: Number )
89+ Base. broadcast ($ op, A, B)
10490 end
10591end
10692
107- @generated function Base.:* (A:: Number , B:: ArrayPartition )
108- N = npartitions (B)
109- expr = :((* ). (A, B. x[i]))
110-
111- build_arraypartition (N, expr)
93+ function Base.:* (A:: Number , B:: ArrayPartition )
94+ Base. broadcast (* , A, B)
11295end
11396
114- @generated function Base.:\ (A:: Number , B:: ArrayPartition )
115- N = npartitions (B)
116- expr = :((/ ). (B. x[i], A))
117-
118- build_arraypartition (N, expr)
97+ function Base.:\ (A:: Number , B:: ArrayPartition )
98+ Base. broadcast (/ , B, A)
11999end
120100
121101# # Functional Constructs
@@ -232,90 +212,72 @@ Base.show(io::IO, m::MIME"text/plain", A::ArrayPartition) = show(io, m, A.x)
232212
233213# # broadcasting
234214
235- struct APStyle <: Broadcast.BroadcastStyle end
236- Base. BroadcastStyle (:: Type{<:ArrayPartition} ) = Broadcast. ArrayStyle {ArrayPartition} ()
237- Base. BroadcastStyle (:: Broadcast.ArrayStyle{ArrayPartition} ,:: Broadcast.ArrayStyle ) = Broadcast. ArrayStyle {ArrayPartition} ()
238- Base. BroadcastStyle (:: Broadcast.ArrayStyle ,:: Broadcast.ArrayStyle{ArrayPartition} ) = Broadcast. ArrayStyle {ArrayPartition} ()
239- Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}} ,:: Type{ElType} ) where ElType = similar (bc)
215+ struct ArrayPartitionStyle{Style <: Broadcast.BroadcastStyle } <: Broadcast.AbstractArrayStyle{Any} end
216+ ArrayPartitionStyle (:: S ) where {S} = ArrayPartitionStyle {S} ()
217+ ArrayPartitionStyle (:: S , :: Val{N} ) where {S,N} = ArrayPartitionStyle (S (Val (N)))
218+ ArrayPartitionStyle (:: Val{N} ) where N = ArrayPartitionStyle {Broadcast.DefaultArrayStyle{N}} ()
240219
241- function Base . copy (bc :: Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}} )
242- ret = Broadcast. flatten (bc)
243- __broadcast (ret . f,ret . args ... )
220+ # promotion rules
221+ function Broadcast. BroadcastStyle ( :: ArrayPartitionStyle{AStyle} , :: ArrayPartitionStyle{BStyle} ) where {AStyle, BStyle}
222+ ArrayPartitionStyle (Broadcast . BroadcastStyle ( AStyle (), BStyle ()) )
244223end
245224
246- @generated function __broadcast (f,as ... )
247-
248- # common number of partitions
249- N = npartitions (as ... )
225+ combine_styles (args :: Tuple{} ) = Broadcast . DefaultArrayStyle {0} ( )
226+ combine_styles (args :: Tuple{Any} ) = Broadcast . result_style (Broadcast . BroadcastStyle (args[ 1 ]))
227+ combine_styles (args :: Tuple{Any, Any} ) = Broadcast . result_style (Broadcast . BroadcastStyle (args[ 1 ]), Broadcast . BroadcastStyle (args[ 2 ]))
228+ @inline combine_styles (args :: Tuple ) = Broadcast . result_style (Broadcast . BroadcastStyle (args[ 1 ]), combine_styles (Base . tail (args)) )
250229
251- # broadcast partitions separately
252- expr = :(broadcast (f,
253- # index partitions
254- $ ((as[d] <: ArrayPartition ? :(as[$ d]. x[i]) : :(as[$ d])
255- for d in 1 : length (as)). .. )))
256- build_arraypartition (N, expr)
230+ function Broadcast. BroadcastStyle (:: Type{ArrayPartition{T,S}} ) where {T, S}
231+ Style = combine_styles ((S. parameters... ,))
232+ ArrayPartitionStyle (Style)
257233end
258234
259- function Base. copyto! (dest:: AbstractArray ,bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}} )
260- ret = Broadcast. flatten (bc)
261- __broadcast! (ret. f,dest,ret. args... )
262- end
263-
264- @generated function __broadcast! (f, dest, as... )
265- # common number of partitions
266- N = npartitions (dest, as... )
267-
268- # broadcast partitions separately
269- quote
270- for i in 1 : $ N
271- broadcast! (f, dest. x[i],
272- # index partitions
273- $ ((as[d] <: ArrayPartition ? :(as[$ d]. x[i]) : :(as[$ d])
274- for d in 1 : length (as)). .. ))
275- end
276- dest
235+ @inline function Base. copy (bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} ) where Style
236+ N = npartitions (bc)
237+ @inline function f (i)
238+ copy (unpack (bc, i))
277239 end
240+ ArrayPartition (f, N)
278241end
279242
280- # # utils
281-
282- """
283- build_arraypartition(N::Int, expr::Expr)
284-
285- Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
286- `expr` with variable `i` set to the partition index in the range of 1 to `N`.
287-
288- This can help to write a type-stable method in cases in which the correct return type can
289- can not be inferred for a simpler implementation with generators.
290- """
291- function build_arraypartition (N:: Int , expr:: Expr )
292- quote
293- @Base . nexprs $ N i-> (A_i = $ expr)
294- partitions = @Base . ncall $ N tuple i-> A_i
295- ArrayPartition (partitions)
243+ @inline function Base. copyto! (dest:: ArrayPartition , bc:: Broadcast.Broadcasted )
244+ N = npartitions (dest, bc)
245+ for i in 1 : N
246+ copyto! (dest. x[i], unpack (bc, i))
296247 end
248+ dest
297249end
298250
251+ # # broadcasting utils
252+
299253"""
300254 npartitions(A...)
301255
302256Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
303257`ArrayPartitions` with a different number of partitions.
304258"""
305259npartitions (A) = 0
306- npartitions (:: Type{ArrayPartition{T,S}} ) where {T,S} = length (S. parameters)
307- npartitions (A, B... ) = common_number (npartitions (A), npartitions (B... ))
260+ npartitions (A:: ArrayPartition ) = length (A. x)
261+ npartitions (bc:: Broadcast.Broadcasted ) = _npartitions (bc. args)
262+ npartitions (A, Bs... ) = common_number (npartitions (A), _npartitions (Bs))
263+
264+ @inline _npartitions (args:: Tuple ) = common_number (npartitions (args[1 ]), _npartitions (Base. tail (args)))
265+ _npartitions (args:: Tuple{Any} ) = npartitions (args[1 ])
266+ _npartitions (args:: Tuple{} ) = 0
267+
268+ # drop axes because it is easier to recompute
269+ @inline unpack (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
270+ @inline unpack (bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
271+ unpack (x,:: Any ) = x
272+ unpack (x:: ArrayPartition , i) = x. x[i]
273+
274+ @inline unpack_args (i, args:: Tuple ) = (unpack (args[1 ], i), unpack_args (i, Base. tail (args))... )
275+ unpack_args (i, args:: Tuple{Any} ) = (unpack (args[1 ], i),)
276+ unpack_args (:: Any , args:: Tuple{} ) = ()
308277
278+ # # utils
309279common_number (a, b) =
310280 a == 0 ? b :
311281 (b == 0 ? a :
312282 (a == b ? a :
313283 throw (DimensionMismatch (" number of partitions must be equal" ))))
314-
315- """
316- parameter(::Type{T})
317-
318- Return type `T` of singleton.
319- """
320- parameter (:: Type{T} ) where {T} = T
321- parameter (:: Type{Type{T}} ) where {T} = T
0 commit comments