1- immutable ArrayPartition{T} <: AbstractVector{Any }
2- x:: T
1+ struct ArrayPartition{T,S <: Tuple } <: AbstractVector{T }
2+ x:: S
33end
4+
5+ # # constructors
6+
47ArrayPartition (x... ) = ArrayPartition ((x... ))
5- function ArrayPartition {T,T2<:Tuple} (x:: T2 ,:: Type{Val{T}} = Val{false })
6- if T
7- return ArrayPartition {T2} (((copy (a) for a in x). .. ))
8+
9+ function ArrayPartition (x:: S , :: Type{Val{copy}} = Val{false }) where {S<: Tuple ,copy}
10+ T = promote_type (eltype .(x)... )
11+
12+ if copy
13+ return ArrayPartition {T,S} (copy .(x))
814 else
9- return ArrayPartition {T2} ((x ... ) )
15+ return ArrayPartition {T,S} (x )
1016 end
1117end
12- Base. similar (A:: ArrayPartition ) = ArrayPartition ((similar .(A. x)). .. )
13- Base. similar (A:: ArrayPartition , dims:: Tuple ) = ArrayPartition ((similar .(A. x)). .. ) # Ignore dims / indices since it's a vector
14- Base. similar {T} (A:: ArrayPartition , :: Type{T} ) = ArrayPartition (similar .(A. x, T)... )
15- Base. similar {T} (A:: ArrayPartition , :: Type{T} , dims:: Tuple ) = ArrayPartition (similar .(A. x, T, dims)... )
1618
17- Base. zeros (A:: ArrayPartition ) = ArrayPartition ((zeros (x) for x in A. x). .. )
18- Base. zeros (A:: ArrayPartition , dims:: Tuple ) = ArrayPartition ((zeros .(A. x)). .. ) # Ignore dims / indices since it's a vector
19- Base. zeros {T} (A:: ArrayPartition , :: Type{T} ) = ArrayPartition (zeros .(A. x, T)... )
20- Base. zeros {T} (A:: ArrayPartition , :: Type{T} , dims:: Tuple ) = ArrayPartition (zeros .(A. x, T, dims)... )
19+ # # similar array partitions
2120
22- Base. copy (A:: ArrayPartition ) = Base. similar (A)
23- Base. eltype (A:: ArrayPartition ) = eltype (A. x[1 ])
21+ Base. similar (A:: ArrayPartition{T,S} ) where {T,S} = ArrayPartition {T,S} (similar .(A. x))
2422
25- # Special to work with units
26- function Base. ones (A:: ArrayPartition )
27- B = similar (A:: ArrayPartition )
28- for i in eachindex (A. x)
29- B. x[i] .= eltype (A. x[i])(one (first (A. x[i])))
30- end
31- B
32- end
33-
34- Base.:+ (A:: ArrayPartition , B:: ArrayPartition ) =
35- ArrayPartition ((x .+ y for (x,y) in zip (A. x,B. x)). .. )
36- Base.:+ (A:: Number , B:: ArrayPartition ) = ArrayPartition ((A .+ x for x in B. x). .. )
37- Base.:+ (A:: ArrayPartition , B:: Number ) = ArrayPartition ((B .+ x for x in A. x). .. )
38- Base.:- (A:: ArrayPartition , B:: ArrayPartition ) =
39- ArrayPartition ((x .- y for (x,y) in zip (A. x,B. x)). .. )
40- Base.:- (A:: Number , B:: ArrayPartition ) = ArrayPartition ((A .- x for x in B. x). .. )
41- Base.:- (A:: ArrayPartition , B:: Number ) = ArrayPartition ((x .- B for x in A. x). .. )
42- Base.:* (A:: Number , B:: ArrayPartition ) = ArrayPartition ((A .* x for x in B. x). .. )
43- Base.:* (A:: ArrayPartition , B:: Number ) = ArrayPartition ((x .* B for x in A. x). .. )
44- Base.:/ (A:: ArrayPartition , B:: Number ) = ArrayPartition ((x ./ B for x in A. x). .. )
45- Base.:\ (A:: Number , B:: ArrayPartition ) = ArrayPartition ((x ./ A for x in B. x). .. )
46-
47- @inline function Base. getindex ( A:: ArrayPartition ,i:: Int )
48- @boundscheck i > length (A) && throw (BoundsError (" Index out of bounds" ))
23+ # ignore dims since array partitions are vectors
24+ Base. similar (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = similar (A)
25+
26+ # similar array partition of common type
27+ @generated function Base. similar (A:: ArrayPartition , :: Type{T} ) where {T}
28+ N = npartitions (A)
29+ expr = :(similar (A. x[i], T))
30+
31+ build_arraypartition (N, expr)
32+ end
33+
34+ # ignore dims since array partitions are vectors
35+ Base. similar (A:: ArrayPartition , :: Type{T} , dims:: NTuple{N,Int} ) where {T,N} = similar (A, T)
36+
37+ # similar array partition with different types
38+ @generated function Base. similar (A:: ArrayPartition , :: Type{T} , :: Type{S} ,
39+ R:: Vararg{Type} ) where {T,S}
40+ N = npartitions (A)
41+ N != length (R) + 2 &&
42+ throw (DimensionMismatch (" number of types must be equal to number of partitions" ))
43+
44+ types = (T, S, parameter .(R)) # new types
45+ expr = :(similar (A. x[i], ($ types)[i]))
46+
47+ build_arraypartition (N, expr)
48+ end
49+
50+ Base. copy (A:: ArrayPartition{T,S} ) where {T,S} = ArrayPartition {T,S} (copy .(A. x))
51+
52+ # # zeros
53+
54+ Base. zeros (A:: ArrayPartition{T,S} ) where {T,S} = ArrayPartition {T,S} (zeros .(A. x))
55+
56+ # ignore dims since array partitions are vectors
57+ Base. zeros (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = zeros (A)
58+
59+ # # ones
60+
61+ # special to work with units
62+ @generated function Base. ones (A:: ArrayPartition )
63+ N = npartitions (A)
64+
65+ expr = :(fill! (similar (A. x[i]), oneunit (eltype (A. x[i]))))
66+
67+ build_arraypartition (N, expr)
68+ end
69+
70+ # ignore dims since array partitions are vectors
71+ Base. ones (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = ones (A)
72+
73+ # # vector space operations
74+
75+ for op in (:+ , :- )
76+ @eval begin
77+ @generated function Base. $op (A:: ArrayPartition , B:: ArrayPartition )
78+ N = npartitions (A, B)
79+ expr = :($ ($ op). (A. x[i], B. x[i]))
80+
81+ build_arraypartition (N, expr)
82+ end
83+
84+ @generated function Base. $op (A:: ArrayPartition , B:: Number )
85+ N = npartitions (A)
86+ expr = :($ ($ op). (A. x[i], B))
87+
88+ build_arraypartition (N, expr)
89+ end
90+
91+ @generated function Base. $op (A:: Number , B:: ArrayPartition )
92+ N = npartitions (B)
93+ expr = :($ ($ op). (A, B. x[i]))
94+
95+ build_arraypartition (N, expr)
96+ end
97+ end
98+ end
99+
100+ for op in (:* , :/ )
101+ @eval @generated function Base. $op (A:: ArrayPartition , B:: Number )
102+ N = npartitions (A)
103+ expr = :($ ($ op). (A. x[i], B))
104+
105+ build_arraypartition (N, expr)
106+ end
107+ end
108+
109+ @generated function Base.:* (A:: Number , B:: ArrayPartition )
110+ N = npartitions (B)
111+ expr = :((* ). (A, B. x[i]))
112+
113+ build_arraypartition (N, expr)
114+ end
115+
116+ @generated function Base.:\ (A:: Number , B:: ArrayPartition )
117+ N = npartitions (B)
118+ expr = :((/ ). (B. x[i], A))
119+
120+ build_arraypartition (N, expr)
121+ end
122+
123+ # # indexing
124+
125+ @inline function Base. getindex (A:: ArrayPartition , i:: Int )
126+ @boundscheck checkbounds (A, i)
49127 @inbounds for j in 1 : length (A. x)
50128 i -= length (A. x[j])
51129 if i <= 0
52130 return A. x[j][length (A. x[j])+ i]
53131 end
54132 end
55133end
56- Base. getindex ( A:: ArrayPartition ,:: Colon ) = [A[i] for i in 1 : length (A)]
134+
135+ """
136+ getindex(A::ArrayPartition, i::Int, j...)
137+
138+ Return the entry at index `j...` of the `i`th partition of `A`.
139+ """
140+ @inline function Base. getindex (A:: ArrayPartition , i:: Int , j... )
141+ @boundscheck 0 < i <= length (A. x) || throw (BoundsError (A. x, i))
142+ @inbounds b = A. x[i]
143+ @boundscheck checkbounds (b, j... )
144+ @inbounds return b[j... ]
145+ end
146+
147+ """
148+ getindex(A::ArrayPartition, ::Colon)
149+
150+ Return vector with all elements of array partition `A`.
151+ """
152+ Base. getindex (A:: ArrayPartition{T,S} , :: Colon ) where {T,S} = T[a for a in Chain (A. x)]
153+
57154@inline function Base. setindex! (A:: ArrayPartition , v, i:: Int )
58- @boundscheck i > length (A) && throw ( BoundsError ( " Index out of bounds " ) )
155+ @boundscheck checkbounds (A, i )
59156 @inbounds for j in 1 : length (A. x)
60157 i -= length (A. x[j])
61158 if i <= 0
@@ -64,28 +161,47 @@ Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)]
64161 end
65162 end
66163end
67- Base. getindex ( A:: ArrayPartition , i:: Int... ) = A. x[i[1 ]][Base. tail (i)... ]
68- Base. setindex! (A:: ArrayPartition , v, i:: Int... ) = A. x[i[1 ]][Base. tail (i)... ]= v
69164
70- function recursivecopy! (A:: ArrayPartition ,B:: ArrayPartition )
71- for (a,b) in zip (A. x,B. x)
72- copy! (a,b)
165+ """
166+ setindex!(A::ArrayPartition, v, i::Int, j...)
167+
168+ Set the entry at index `j...` of the `i`th partition of `A` to `v`.
169+ """
170+ @inline function Base. setindex! (A:: ArrayPartition , v, i:: Int , j... )
171+ @boundscheck 0 < i <= length (A. x) || throw (BoundsError (A. x, i))
172+ @inbounds b = A. x[i]
173+ @boundscheck checkbounds (b, j... )
174+ @inbounds b[j... ] = v
175+ end
176+
177+ # # recursive methods
178+
179+ function recursivecopy! (A:: ArrayPartition , B:: ArrayPartition )
180+ for (a, b) in zip (A. x, B. x)
181+ recursivecopy! (a, b)
73182 end
74183end
75184
76- recursive_one (A:: ArrayPartition ) = recursive_one (first (A. x))
77185recursive_mean (A:: ArrayPartition ) = mean ((recursive_mean (x) for x in A. x))
78- Base. zero (A:: ArrayPartition ) = zero (first (A. x))
79- Base. first (A:: ArrayPartition ) = first (first (A. x))
80186
81- Base. start (A:: ArrayPartition ) = start (chain (A. x... ))
82- Base. next (A:: ArrayPartition ,state) = next (chain (A. x... ),state)
83- Base. done (A:: ArrayPartition ,state) = done (chain (A. x... ),state)
187+ # note: consider only first partition for recursive one and eltype
188+ recursive_one (A:: ArrayPartition ) = recursive_one (first (A. x))
189+ recursive_eltype (A:: ArrayPartition ) = recursive_eltype (first (A. x))
190+
191+ # # iteration
192+
193+ Base. start (A:: ArrayPartition ) = start (Chain (A. x))
194+ Base. next (A:: ArrayPartition ,state) = next (Chain (A. x),state)
195+ Base. done (A:: ArrayPartition ,state) = done (Chain (A. x),state)
84196
85197Base. length (A:: ArrayPartition ) = sum ((length (x) for x in A. x))
86198Base. size (A:: ArrayPartition ) = (length (A),)
87- Base. isempty (A:: ArrayPartition ) = (length (A) == 0 )
88- Base. eachindex (A:: ArrayPartition ) = Base. OneTo (length (A))
199+
200+ # redefine first and last to avoid slow and not type-stable indexing
201+ Base. first (A:: ArrayPartition ) = first (first (A. x))
202+ Base. last (A:: ArrayPartition ) = last (last (A. x))
203+
204+ # # display
89205
90206# restore the type rendering in Juno
91207Juno. @render Juno. Inline x:: ArrayPartition begin
@@ -97,23 +213,83 @@ Base.show(io::IO,A::ArrayPartition) = (Base.show.(io,A.x); nothing)
97213Base. display (A:: ArrayPartition ) = (println (summary (A));display .(A. x);nothing )
98214Base. display (io:: IO ,A:: ArrayPartition ) = (println (summary (A));display .(io,A. x);nothing )
99215
100- add_idxs (x,expr) = expr
101- add_idxs {T<:ArrayPartition} (:: Type{T} ,expr) = :($ (expr). x[i])
216+ # # broadcasting
217+
218+ Base. Broadcast. _containertype (:: Type{<:ArrayPartition} ) = ArrayPartition
219+ Base. Broadcast. promote_containertype (:: Type{ArrayPartition} , :: Type ) = ArrayPartition
220+ Base. Broadcast. promote_containertype (:: Type , :: Type{ArrayPartition} ) = ArrayPartition
221+ Base. Broadcast. promote_containertype (:: Type{ArrayPartition} , :: Type{ArrayPartition} ) = ArrayPartition
222+ Base. Broadcast. promote_containertype (:: Type{ArrayPartition} , :: Type{Array} ) = ArrayPartition
223+ Base. Broadcast. promote_containertype (:: Type{Array} , :: Type{ArrayPartition} ) = ArrayPartition
224+
225+ @generated function Base. Broadcast. broadcast_c (f, :: Type{ArrayPartition} , as... )
226+ # common number of partitions
227+ N = npartitions (as... )
228+
229+ # broadcast partitions separately
230+ expr = :(broadcast (f,
231+ # index partitions
232+ $ ((as[d] <: ArrayPartition ? :(as[$ d]. x[i]) : :(as[$ d])
233+ for d in 1 : length (as)). .. )))
102234
103- @generated function Base. broadcast! (f,A:: ArrayPartition ,B... )
104- exs = ((add_idxs (B[i],:(B[$ i])) for i in eachindex (B)). .. )
105- :(for i in eachindex (A. x)
106- broadcast! (f,A. x[i],$ (exs... ))
107- end )
235+ build_arraypartition (N, expr)
108236end
109237
110- @generated function Base. broadcast (f,B:: Union{Number,ArrayPartition} ...)
111- arr_idx = 0
112- for (i,b) in enumerate (B)
113- if b <: ArrayPartition
114- arr_idx = i
115- break
238+ @generated function Base. Broadcast. broadcast_c! (f, :: Type{ArrayPartition} , :: Type ,
239+ dest:: ArrayPartition , as... )
240+ # common number of partitions
241+ N = npartitions (dest, as... )
242+
243+ # broadcast partitions separately
244+ quote
245+ for i in 1 : $ N
246+ broadcast! (f, dest. x[i],
247+ # index partitions
248+ $ ((as[d] <: ArrayPartition ? :(as[$ d]. x[i]) : :(as[$ d])
249+ for d in 1 : length (as)). .. ))
250+ end
251+ dest
116252 end
117- end
118- :(A = similar (B[$ arr_idx]); broadcast! (f,A,B... ); A)
119253end
254+
255+ # # utils
256+
257+ """
258+ build_arraypartition(N::Int, expr::Expr)
259+
260+ Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
261+ `expr` with variable `i` set to the partition index in the range of 1 to `N`.
262+
263+ This can help to write a type-stable method in cases in which the correct return type can
264+ can not be inferred for a simpler implementation with generators.
265+ """
266+ function build_arraypartition (N:: Int , expr:: Expr )
267+ quote
268+ @Base . nexprs $ N i-> (A_i = $ expr)
269+ partitions = @Base . ncall $ N tuple i-> A_i
270+ ArrayPartition (partitions)
271+ end
272+ end
273+
274+ """
275+ npartitions(A...)
276+
277+ Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
278+ `ArrayPartitions` with a different number of partitions.
279+ """
280+ npartitions (A) = 0
281+ npartitions (:: Type{ArrayPartition{T,S}} ) where {T,S} = length (S. parameters)
282+ npartitions (A, B... ) = common_number (npartitions (A), npartitions (B... ))
283+
284+ common_number (a, b) =
285+ a == 0 ? b :
286+ (b == 0 ? a :
287+ (a == b ? a :
288+ throw (DimensionMismatch (" number of partitions must be equal" ))))
289+
290+ """
291+ parameter(::Type{T})
292+
293+ Return type `T` of singleton.
294+ """
295+ parameter (:: Type{T} ) where {T} = T
0 commit comments