Skip to content

Commit 09e4ad1

Browse files
Merge pull request #18 from devmotion/change_arraypartition
Improved ArrayPartition
2 parents 85e6bd7 + d5b657c commit 09e4ad1

File tree

2 files changed

+287
-75
lines changed

2 files changed

+287
-75
lines changed

src/array_partition.jl

Lines changed: 246 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,158 @@
1-
immutable ArrayPartition{T} <: AbstractVector{Any}
2-
x::T
1+
struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T}
2+
x::S
33
end
4+
5+
## constructors
6+
47
ArrayPartition(x...) = ArrayPartition((x...))
5-
function ArrayPartition{T,T2<:Tuple}(x::T2,::Type{Val{T}}=Val{false})
6-
if T
7-
return ArrayPartition{T2}(((copy(a) for a in x)...))
8+
9+
function ArrayPartition(x::S, ::Type{Val{copy}}=Val{false}) where {S<:Tuple,copy}
10+
T = promote_type(eltype.(x)...)
11+
12+
if copy
13+
return ArrayPartition{T,S}(copy.(x))
814
else
9-
return ArrayPartition{T2}((x...))
15+
return ArrayPartition{T,S}(x)
1016
end
1117
end
12-
Base.similar(A::ArrayPartition) = ArrayPartition((similar.(A.x))...)
13-
Base.similar(A::ArrayPartition, dims::Tuple) = ArrayPartition((similar.(A.x))...) # Ignore dims / indices since it's a vector
14-
Base.similar{T}(A::ArrayPartition, ::Type{T}) = ArrayPartition(similar.(A.x, T)...)
15-
Base.similar{T}(A::ArrayPartition, ::Type{T}, dims::Tuple) = ArrayPartition(similar.(A.x, T, dims)...)
1618

17-
Base.zeros(A::ArrayPartition) = ArrayPartition((zeros(x) for x in A.x)...)
18-
Base.zeros(A::ArrayPartition, dims::Tuple) = ArrayPartition((zeros.(A.x))...) # Ignore dims / indices since it's a vector
19-
Base.zeros{T}(A::ArrayPartition, ::Type{T}) = ArrayPartition(zeros.(A.x, T)...)
20-
Base.zeros{T}(A::ArrayPartition, ::Type{T}, dims::Tuple) = ArrayPartition(zeros.(A.x, T, dims)...)
19+
## similar array partitions
2120

22-
Base.copy(A::ArrayPartition) = Base.similar(A)
23-
Base.eltype(A::ArrayPartition) = eltype(A.x[1])
21+
Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(A.x))
2422

25-
# Special to work with units
26-
function Base.ones(A::ArrayPartition)
27-
B = similar(A::ArrayPartition)
28-
for i in eachindex(A.x)
29-
B.x[i] .= eltype(A.x[i])(one(first(A.x[i])))
30-
end
31-
B
32-
end
33-
34-
Base.:+(A::ArrayPartition, B::ArrayPartition) =
35-
ArrayPartition((x .+ y for (x,y) in zip(A.x,B.x))...)
36-
Base.:+(A::Number, B::ArrayPartition) = ArrayPartition((A .+ x for x in B.x)...)
37-
Base.:+(A::ArrayPartition, B::Number) = ArrayPartition((B .+ x for x in A.x)...)
38-
Base.:-(A::ArrayPartition, B::ArrayPartition) =
39-
ArrayPartition((x .- y for (x,y) in zip(A.x,B.x))...)
40-
Base.:-(A::Number, B::ArrayPartition) = ArrayPartition((A .- x for x in B.x)...)
41-
Base.:-(A::ArrayPartition, B::Number) = ArrayPartition((x .- B for x in A.x)...)
42-
Base.:*(A::Number, B::ArrayPartition) = ArrayPartition((A .* x for x in B.x)...)
43-
Base.:*(A::ArrayPartition, B::Number) = ArrayPartition((x .* B for x in A.x)...)
44-
Base.:/(A::ArrayPartition, B::Number) = ArrayPartition((x ./ B for x in A.x)...)
45-
Base.:\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...)
46-
47-
@inline function Base.getindex( A::ArrayPartition,i::Int)
48-
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
23+
# ignore dims since array partitions are vectors
24+
Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A)
25+
26+
# similar array partition of common type
27+
@generated function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
28+
N = npartitions(A)
29+
expr = :(similar(A.x[i], T))
30+
31+
build_arraypartition(N, expr)
32+
end
33+
34+
# ignore dims since array partitions are vectors
35+
Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T)
36+
37+
# similar array partition with different types
38+
@generated function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S},
39+
R::Vararg{Type}) where {T,S}
40+
N = npartitions(A)
41+
N != length(R) + 2 &&
42+
throw(DimensionMismatch("number of types must be equal to number of partitions"))
43+
44+
types = (T, S, parameter.(R)) # new types
45+
expr = :(similar(A.x[i], ($types)[i]))
46+
47+
build_arraypartition(N, expr)
48+
end
49+
50+
Base.copy(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(copy.(A.x))
51+
52+
## zeros
53+
54+
Base.zeros(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zeros.(A.x))
55+
56+
# ignore dims since array partitions are vectors
57+
Base.zeros(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zeros(A)
58+
59+
## ones
60+
61+
# special to work with units
62+
@generated function Base.ones(A::ArrayPartition)
63+
N = npartitions(A)
64+
65+
expr = :(fill!(similar(A.x[i]), oneunit(eltype(A.x[i]))))
66+
67+
build_arraypartition(N, expr)
68+
end
69+
70+
# ignore dims since array partitions are vectors
71+
Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A)
72+
73+
## vector space operations
74+
75+
for op in (:+, :-)
76+
@eval begin
77+
@generated function Base.$op(A::ArrayPartition, B::ArrayPartition)
78+
N = npartitions(A, B)
79+
expr = :($($op).(A.x[i], B.x[i]))
80+
81+
build_arraypartition(N, expr)
82+
end
83+
84+
@generated function Base.$op(A::ArrayPartition, B::Number)
85+
N = npartitions(A)
86+
expr = :($($op).(A.x[i], B))
87+
88+
build_arraypartition(N, expr)
89+
end
90+
91+
@generated function Base.$op(A::Number, B::ArrayPartition)
92+
N = npartitions(B)
93+
expr = :($($op).(A, B.x[i]))
94+
95+
build_arraypartition(N, expr)
96+
end
97+
end
98+
end
99+
100+
for op in (:*, :/)
101+
@eval @generated function Base.$op(A::ArrayPartition, B::Number)
102+
N = npartitions(A)
103+
expr = :($($op).(A.x[i], B))
104+
105+
build_arraypartition(N, expr)
106+
end
107+
end
108+
109+
@generated function Base.:*(A::Number, B::ArrayPartition)
110+
N = npartitions(B)
111+
expr = :((*).(A, B.x[i]))
112+
113+
build_arraypartition(N, expr)
114+
end
115+
116+
@generated function Base.:\(A::Number, B::ArrayPartition)
117+
N = npartitions(B)
118+
expr = :((/).(B.x[i], A))
119+
120+
build_arraypartition(N, expr)
121+
end
122+
123+
## indexing
124+
125+
@inline function Base.getindex(A::ArrayPartition, i::Int)
126+
@boundscheck checkbounds(A, i)
49127
@inbounds for j in 1:length(A.x)
50128
i -= length(A.x[j])
51129
if i <= 0
52130
return A.x[j][length(A.x[j])+i]
53131
end
54132
end
55133
end
56-
Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)]
134+
135+
"""
136+
getindex(A::ArrayPartition, i::Int, j...)
137+
138+
Return the entry at index `j...` of the `i`th partition of `A`.
139+
"""
140+
@inline function Base.getindex(A::ArrayPartition, i::Int, j...)
141+
@boundscheck 0 < i <= length(A.x) || throw(BoundsError(A.x, i))
142+
@inbounds b = A.x[i]
143+
@boundscheck checkbounds(b, j...)
144+
@inbounds return b[j...]
145+
end
146+
147+
"""
148+
getindex(A::ArrayPartition, ::Colon)
149+
150+
Return vector with all elements of array partition `A`.
151+
"""
152+
Base.getindex(A::ArrayPartition{T,S}, ::Colon) where {T,S} = T[a for a in Chain(A.x)]
153+
57154
@inline function Base.setindex!(A::ArrayPartition, v, i::Int)
58-
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
155+
@boundscheck checkbounds(A, i)
59156
@inbounds for j in 1:length(A.x)
60157
i -= length(A.x[j])
61158
if i <= 0
@@ -64,28 +161,47 @@ Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)]
64161
end
65162
end
66163
end
67-
Base.getindex( A::ArrayPartition, i::Int...) = A.x[i[1]][Base.tail(i)...]
68-
Base.setindex!(A::ArrayPartition, v, i::Int...) = A.x[i[1]][Base.tail(i)...]=v
69164

70-
function recursivecopy!(A::ArrayPartition,B::ArrayPartition)
71-
for (a,b) in zip(A.x,B.x)
72-
copy!(a,b)
165+
"""
166+
setindex!(A::ArrayPartition, v, i::Int, j...)
167+
168+
Set the entry at index `j...` of the `i`th partition of `A` to `v`.
169+
"""
170+
@inline function Base.setindex!(A::ArrayPartition, v, i::Int, j...)
171+
@boundscheck 0 < i <= length(A.x) || throw(BoundsError(A.x, i))
172+
@inbounds b = A.x[i]
173+
@boundscheck checkbounds(b, j...)
174+
@inbounds b[j...] = v
175+
end
176+
177+
## recursive methods
178+
179+
function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
180+
for (a, b) in zip(A.x, B.x)
181+
recursivecopy!(a, b)
73182
end
74183
end
75184

76-
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))
77185
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
78-
Base.zero(A::ArrayPartition) = zero(first(A.x))
79-
Base.first(A::ArrayPartition) = first(first(A.x))
80186

81-
Base.start(A::ArrayPartition) = start(chain(A.x...))
82-
Base.next(A::ArrayPartition,state) = next(chain(A.x...),state)
83-
Base.done(A::ArrayPartition,state) = done(chain(A.x...),state)
187+
# note: consider only first partition for recursive one and eltype
188+
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))
189+
recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
190+
191+
## iteration
192+
193+
Base.start(A::ArrayPartition) = start(Chain(A.x))
194+
Base.next(A::ArrayPartition,state) = next(Chain(A.x),state)
195+
Base.done(A::ArrayPartition,state) = done(Chain(A.x),state)
84196

85197
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
86198
Base.size(A::ArrayPartition) = (length(A),)
87-
Base.isempty(A::ArrayPartition) = (length(A) == 0)
88-
Base.eachindex(A::ArrayPartition) = Base.OneTo(length(A))
199+
200+
# redefine first and last to avoid slow and not type-stable indexing
201+
Base.first(A::ArrayPartition) = first(first(A.x))
202+
Base.last(A::ArrayPartition) = last(last(A.x))
203+
204+
## display
89205

90206
# restore the type rendering in Juno
91207
Juno.@render Juno.Inline x::ArrayPartition begin
@@ -97,23 +213,83 @@ Base.show(io::IO,A::ArrayPartition) = (Base.show.(io,A.x); nothing)
97213
Base.display(A::ArrayPartition) = (println(summary(A));display.(A.x);nothing)
98214
Base.display(io::IO,A::ArrayPartition) = (println(summary(A));display.(io,A.x);nothing)
99215

100-
add_idxs(x,expr) = expr
101-
add_idxs{T<:ArrayPartition}(::Type{T},expr) = :($(expr).x[i])
216+
## broadcasting
217+
218+
Base.Broadcast._containertype(::Type{<:ArrayPartition}) = ArrayPartition
219+
Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type) = ArrayPartition
220+
Base.Broadcast.promote_containertype(::Type, ::Type{ArrayPartition}) = ArrayPartition
221+
Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type{ArrayPartition}) = ArrayPartition
222+
Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type{Array}) = ArrayPartition
223+
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{ArrayPartition}) = ArrayPartition
224+
225+
@generated function Base.Broadcast.broadcast_c(f, ::Type{ArrayPartition}, as...)
226+
# common number of partitions
227+
N = npartitions(as...)
228+
229+
# broadcast partitions separately
230+
expr = :(broadcast(f,
231+
# index partitions
232+
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
233+
for d in 1:length(as))...)))
102234

103-
@generated function Base.broadcast!(f,A::ArrayPartition,B...)
104-
exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
105-
:(for i in eachindex(A.x)
106-
broadcast!(f,A.x[i],$(exs...))
107-
end)
235+
build_arraypartition(N, expr)
108236
end
109237

110-
@generated function Base.broadcast(f,B::Union{Number,ArrayPartition}...)
111-
arr_idx = 0
112-
for (i,b) in enumerate(B)
113-
if b <: ArrayPartition
114-
arr_idx = i
115-
break
238+
@generated function Base.Broadcast.broadcast_c!(f, ::Type{ArrayPartition}, ::Type,
239+
dest::ArrayPartition, as...)
240+
# common number of partitions
241+
N = npartitions(dest, as...)
242+
243+
# broadcast partitions separately
244+
quote
245+
for i in 1:$N
246+
broadcast!(f, dest.x[i],
247+
# index partitions
248+
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
249+
for d in 1:length(as))...))
250+
end
251+
dest
116252
end
117-
end
118-
:(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
119253
end
254+
255+
## utils
256+
257+
"""
258+
build_arraypartition(N::Int, expr::Expr)
259+
260+
Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
261+
`expr` with variable `i` set to the partition index in the range of 1 to `N`.
262+
263+
This can help to write a type-stable method in cases in which the correct return type can
264+
can not be inferred for a simpler implementation with generators.
265+
"""
266+
function build_arraypartition(N::Int, expr::Expr)
267+
quote
268+
@Base.nexprs $N i->(A_i = $expr)
269+
partitions = @Base.ncall $N tuple i->A_i
270+
ArrayPartition(partitions)
271+
end
272+
end
273+
274+
"""
275+
npartitions(A...)
276+
277+
Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
278+
`ArrayPartitions` with a different number of partitions.
279+
"""
280+
npartitions(A) = 0
281+
npartitions(::Type{ArrayPartition{T,S}}) where {T,S} = length(S.parameters)
282+
npartitions(A, B...) = common_number(npartitions(A), npartitions(B...))
283+
284+
common_number(a, b) =
285+
a == 0 ? b :
286+
(b == 0 ? a :
287+
(a == b ? a :
288+
throw(DimensionMismatch("number of partitions must be equal"))))
289+
290+
"""
291+
parameter(::Type{T})
292+
293+
Return type `T` of singleton.
294+
"""
295+
parameter(::Type{T}) where {T} = T

0 commit comments

Comments
 (0)