Skip to content

Commit 67d3602

Browse files
committed
Type stable implementation of ArrayPartition
1 parent 85e6bd7 commit 67d3602

File tree

2 files changed

+267
-71
lines changed

2 files changed

+267
-71
lines changed

src/array_partition.jl

Lines changed: 239 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,158 @@
1-
immutable ArrayPartition{T} <: AbstractVector{Any}
2-
x::T
1+
struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T}
2+
x::S
33
end
4+
5+
## constructors
6+
47
ArrayPartition(x...) = ArrayPartition((x...))
5-
function ArrayPartition{T,T2<:Tuple}(x::T2,::Type{Val{T}}=Val{false})
6-
if T
7-
return ArrayPartition{T2}(((copy(a) for a in x)...))
8+
9+
function ArrayPartition(x::S, ::Type{Val{copy}}=Val{false}) where {S<:Tuple,copy}
10+
T = promote_type(eltype.(x)...)
11+
12+
if copy
13+
return ArrayPartition{T,S}(copy.(x))
814
else
9-
return ArrayPartition{T2}((x...))
15+
return ArrayPartition{T,S}(x)
1016
end
1117
end
12-
Base.similar(A::ArrayPartition) = ArrayPartition((similar.(A.x))...)
13-
Base.similar(A::ArrayPartition, dims::Tuple) = ArrayPartition((similar.(A.x))...) # Ignore dims / indices since it's a vector
14-
Base.similar{T}(A::ArrayPartition, ::Type{T}) = ArrayPartition(similar.(A.x, T)...)
15-
Base.similar{T}(A::ArrayPartition, ::Type{T}, dims::Tuple) = ArrayPartition(similar.(A.x, T, dims)...)
1618

17-
Base.zeros(A::ArrayPartition) = ArrayPartition((zeros(x) for x in A.x)...)
18-
Base.zeros(A::ArrayPartition, dims::Tuple) = ArrayPartition((zeros.(A.x))...) # Ignore dims / indices since it's a vector
19-
Base.zeros{T}(A::ArrayPartition, ::Type{T}) = ArrayPartition(zeros.(A.x, T)...)
20-
Base.zeros{T}(A::ArrayPartition, ::Type{T}, dims::Tuple) = ArrayPartition(zeros.(A.x, T, dims)...)
19+
## similar array partitions
2120

22-
Base.copy(A::ArrayPartition) = Base.similar(A)
23-
Base.eltype(A::ArrayPartition) = eltype(A.x[1])
21+
Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(A.x))
2422

25-
# Special to work with units
26-
function Base.ones(A::ArrayPartition)
27-
B = similar(A::ArrayPartition)
28-
for i in eachindex(A.x)
29-
B.x[i] .= eltype(A.x[i])(one(first(A.x[i])))
30-
end
31-
B
32-
end
33-
34-
Base.:+(A::ArrayPartition, B::ArrayPartition) =
35-
ArrayPartition((x .+ y for (x,y) in zip(A.x,B.x))...)
36-
Base.:+(A::Number, B::ArrayPartition) = ArrayPartition((A .+ x for x in B.x)...)
37-
Base.:+(A::ArrayPartition, B::Number) = ArrayPartition((B .+ x for x in A.x)...)
38-
Base.:-(A::ArrayPartition, B::ArrayPartition) =
39-
ArrayPartition((x .- y for (x,y) in zip(A.x,B.x))...)
40-
Base.:-(A::Number, B::ArrayPartition) = ArrayPartition((A .- x for x in B.x)...)
41-
Base.:-(A::ArrayPartition, B::Number) = ArrayPartition((x .- B for x in A.x)...)
42-
Base.:*(A::Number, B::ArrayPartition) = ArrayPartition((A .* x for x in B.x)...)
43-
Base.:*(A::ArrayPartition, B::Number) = ArrayPartition((x .* B for x in A.x)...)
44-
Base.:/(A::ArrayPartition, B::Number) = ArrayPartition((x ./ B for x in A.x)...)
45-
Base.:\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...)
46-
47-
@inline function Base.getindex( A::ArrayPartition,i::Int)
48-
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
23+
# ignore dims since array partitions are vectors
24+
Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A)
25+
26+
# similar array partition of common type
27+
@generated function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
28+
N = npartitions(A)
29+
expr = :(similar(A.x[i], T))
30+
31+
build_arraypartition(N, expr)
32+
end
33+
34+
# ignore dims since array partitions are vectors
35+
Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T)
36+
37+
# similar array partition with different types
38+
@generated function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S},
39+
R::Vararg{Type}) where {T,S}
40+
N = npartitions(A)
41+
N != length(R) + 2 &&
42+
throw(DimensionMismatch("number of types must be equal to number of partitions"))
43+
44+
types = (T, S, parameter.(R)) # new types
45+
expr = :(similar(A.x[i], ($types)[i]))
46+
47+
build_arraypartition(N, expr)
48+
end
49+
50+
Base.copy(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(copy.(A.x))
51+
52+
## zeros
53+
54+
Base.zeros(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zeros.(A.x))
55+
56+
# ignore dims since array partitions are vectors
57+
Base.zeros(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zeros(A)
58+
59+
## ones
60+
61+
# special to work with units
62+
@generated function Base.ones(A::ArrayPartition)
63+
N = npartitions(A)
64+
65+
expr = :(fill!(similar(A.x[i]), oneunit(eltype(A.x[i]))))
66+
67+
build_arraypartition(N, expr)
68+
end
69+
70+
# ignore dims since array partitions are vectors
71+
Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A)
72+
73+
## vector space operations
74+
75+
for op in (:+, :-)
76+
@eval begin
77+
@generated function Base.$op(A::ArrayPartition, B::ArrayPartition)
78+
N = npartitions(A, B)
79+
expr = :($($op).(A.x[i], B.x[i]))
80+
81+
build_arraypartition(N, expr)
82+
end
83+
84+
@generated function Base.$op(A::ArrayPartition, B::Number)
85+
N = npartitions(A)
86+
expr = :($($op).(A.x[i], B))
87+
88+
build_arraypartition(N, expr)
89+
end
90+
91+
@generated function Base.$op(A::Number, B::ArrayPartition)
92+
N = npartitions(B)
93+
expr = :($($op).(A, B.x[i]))
94+
95+
build_arraypartition(N, expr)
96+
end
97+
end
98+
end
99+
100+
for op in (:*, :/)
101+
@eval @generated function Base.$op(A::ArrayPartition, B::Number)
102+
N = npartitions(A)
103+
expr = :($($op).(A.x[i], B))
104+
105+
build_arraypartition(N, expr)
106+
end
107+
end
108+
109+
@generated function Base.:*(A::Number, B::ArrayPartition)
110+
N = npartitions(B)
111+
expr = :((*).(A, B.x[i]))
112+
113+
build_arraypartition(N, expr)
114+
end
115+
116+
@generated function Base.:\(A::Number, B::ArrayPartition)
117+
N = npartitions(B)
118+
expr = :((/).(B.x[i], A))
119+
120+
build_arraypartition(N, expr)
121+
end
122+
123+
## indexing
124+
125+
@inline function Base.getindex(A::ArrayPartition, i::Int)
126+
@boundscheck checkbounds(A, i)
49127
@inbounds for j in 1:length(A.x)
50128
i -= length(A.x[j])
51129
if i <= 0
52130
return A.x[j][length(A.x[j])+i]
53131
end
54132
end
55133
end
56-
Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)]
134+
135+
"""
136+
getindex(A::ArrayPartition, i::Int, j...)
137+
138+
Return the entry at index `j...` of the `i`th partition of `A`.
139+
"""
140+
@inline function Base.getindex(A::ArrayPartition, i::Int, j...)
141+
@boundscheck 0 < i <= length(A.x) || throw(BoundsError(A.x, i))
142+
@inbounds b = A.x[i]
143+
@boundscheck checkbounds(b, j...)
144+
@inbounds return b[j...]
145+
end
146+
147+
"""
148+
getindex(A::ArrayPartition, ::Colon)
149+
150+
Return vector with all elements of array partition `A`.
151+
"""
152+
Base.getindex(A::ArrayPartition{T,S}, ::Colon) where {T,S} = T[a for a in Chain(A.x)]
153+
57154
@inline function Base.setindex!(A::ArrayPartition, v, i::Int)
58-
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
155+
@boundscheck checkbounds(A, i)
59156
@inbounds for j in 1:length(A.x)
60157
i -= length(A.x[j])
61158
if i <= 0
@@ -64,28 +161,41 @@ Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)]
64161
end
65162
end
66163
end
67-
Base.getindex( A::ArrayPartition, i::Int...) = A.x[i[1]][Base.tail(i)...]
68-
Base.setindex!(A::ArrayPartition, v, i::Int...) = A.x[i[1]][Base.tail(i)...]=v
69164

70-
function recursivecopy!(A::ArrayPartition,B::ArrayPartition)
71-
for (a,b) in zip(A.x,B.x)
72-
copy!(a,b)
165+
"""
166+
setindex!(A::ArrayPartition, v, i::Int, j...)
167+
168+
Set the entry at index `j...` of the `i`th partition of `A` to `v`.
169+
"""
170+
@inline function Base.setindex!(A::ArrayPartition, v, i::Int, j...)
171+
@boundscheck 0 < i <= length(A.x) || throw(BoundsError(A.x, i))
172+
@inbounds b = A.x[i]
173+
@boundscheck checkbounds(b, j...)
174+
@inbounds b[j...] = v
175+
end
176+
177+
## recursive methods
178+
179+
function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
180+
for (a, b) in zip(A.x, B.x)
181+
recursivecopy!(a, b)
73182
end
74183
end
75184

76185
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))
186+
77187
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
78-
Base.zero(A::ArrayPartition) = zero(first(A.x))
79-
Base.first(A::ArrayPartition) = first(first(A.x))
80188

81-
Base.start(A::ArrayPartition) = start(chain(A.x...))
82-
Base.next(A::ArrayPartition,state) = next(chain(A.x...),state)
83-
Base.done(A::ArrayPartition,state) = done(chain(A.x...),state)
189+
## iteration
190+
191+
Base.start(A::ArrayPartition) = start(Chain(A.x))
192+
Base.next(A::ArrayPartition,state) = next(Chain(A.x),state)
193+
Base.done(A::ArrayPartition,state) = done(Chain(A.x),state)
84194

85195
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
86196
Base.size(A::ArrayPartition) = (length(A),)
87-
Base.isempty(A::ArrayPartition) = (length(A) == 0)
88-
Base.eachindex(A::ArrayPartition) = Base.OneTo(length(A))
197+
198+
## display
89199

90200
# restore the type rendering in Juno
91201
Juno.@render Juno.Inline x::ArrayPartition begin
@@ -97,23 +207,83 @@ Base.show(io::IO,A::ArrayPartition) = (Base.show.(io,A.x); nothing)
97207
Base.display(A::ArrayPartition) = (println(summary(A));display.(A.x);nothing)
98208
Base.display(io::IO,A::ArrayPartition) = (println(summary(A));display.(io,A.x);nothing)
99209

100-
add_idxs(x,expr) = expr
101-
add_idxs{T<:ArrayPartition}(::Type{T},expr) = :($(expr).x[i])
210+
## broadcasting
211+
212+
Base.Broadcast._containertype(::Type{<:ArrayPartition}) = ArrayPartition
213+
Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type) = ArrayPartition
214+
Base.Broadcast.promote_containertype(::Type, ::Type{ArrayPartition}) = ArrayPartition
215+
Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type{ArrayPartition}) = ArrayPartition
216+
Base.Broadcast.promote_containertype(::Type{ArrayPartition}, ::Type{Array}) = ArrayPartition
217+
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{ArrayPartition}) = ArrayPartition
218+
219+
@generated function Base.Broadcast.broadcast_c(f, ::Type{ArrayPartition}, as...)
220+
# common number of partitions
221+
N = npartitions(as...)
102222

103-
@generated function Base.broadcast!(f,A::ArrayPartition,B...)
104-
exs = ((add_idxs(B[i],:(B[$i])) for i in eachindex(B))...)
105-
:(for i in eachindex(A.x)
106-
broadcast!(f,A.x[i],$(exs...))
107-
end)
223+
# broadcast partitions separately
224+
expr = :(broadcast(f,
225+
# index partitions
226+
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
227+
for d in 1:length(as))...)))
228+
229+
build_arraypartition(N, expr)
108230
end
109231

110-
@generated function Base.broadcast(f,B::Union{Number,ArrayPartition}...)
111-
arr_idx = 0
112-
for (i,b) in enumerate(B)
113-
if b <: ArrayPartition
114-
arr_idx = i
115-
break
232+
@generated function Base.Broadcast.broadcast_c!(f, ::Type{ArrayPartition}, ::Type,
233+
dest::ArrayPartition, as...)
234+
# common number of partitions
235+
N = npartitions(dest, as...)
236+
237+
# broadcast partitions separately
238+
quote
239+
for i in 1:$N
240+
broadcast!(f, dest.x[i],
241+
# index partitions
242+
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
243+
for d in 1:length(as))...))
244+
end
245+
dest
246+
end
247+
end
248+
249+
## utils
250+
251+
"""
252+
build_arraypartition(N::Int, expr::Expr)
253+
254+
Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
255+
`expr` with variable `i` set to the partition index in the range of 1 to `N`.
256+
257+
This can help to write a type-stable method in cases in which the correct return type can
258+
can not be inferred for a simpler implementation with generators.
259+
"""
260+
function build_arraypartition(N::Int, expr::Expr)
261+
quote
262+
@Base.nexprs $N i->(A_i = $expr)
263+
partitions = @Base.ncall $N tuple i->A_i
264+
ArrayPartition(partitions)
116265
end
117-
end
118-
:(A = similar(B[$arr_idx]); broadcast!(f,A,B...); A)
119266
end
267+
268+
"""
269+
npartitions(A...)
270+
271+
Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
272+
`ArrayPartitions` with a different number of partitions.
273+
"""
274+
npartitions(A) = 0
275+
npartitions(::Type{ArrayPartition{T,S}}) where {T,S} = length(S.parameters)
276+
npartitions(A, B...) = common_number(npartitions(A), npartitions(B...))
277+
278+
common_number(a, b) =
279+
a == 0 ? b :
280+
(b == 0 ? a :
281+
(a == b ? a :
282+
throw(DimensionMismatch("number of partitions must be equal"))))
283+
284+
"""
285+
parameter(::Type{T})
286+
287+
Return type `T` of singleton.
288+
"""
289+
parameter(::Type{T}) where {T} = T

test/partitions_test.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,38 @@ p .= (*).(p,a)
2424
p .= (*).(p,p2)
2525
K = (*).(p,p2)
2626

27+
## inference tests
28+
2729
x = ArrayPartition([1, 2], [3.0, 4.0])
30+
31+
# similar partitions
2832
@inferred(similar(x))
2933
@inferred(similar(x, (2, 2)))
30-
@test_broken @inferred(similar(x, Int, (2, 2)))
31-
@test_broken @inferred(similar(x, (Int, Float64), (2, 2)))
34+
@inferred(similar(x, Int))
35+
@inferred(similar(x, Int, (2, 2)))
36+
@inferred(similar(x, Int, Float64))
37+
38+
# zeros
39+
@inferred(zeros(x))
40+
@inferred(zeros(x, (2,2)))
41+
42+
# ones
43+
@inferred(ones(x))
44+
@inferred(ones(x, (2,2)))
45+
46+
# vector space calculations
47+
@inferred(x+5)
48+
@inferred(5+x)
49+
@inferred(x-5)
50+
@inferred(5-x)
51+
@inferred(x*5)
52+
@inferred(5*x)
53+
@inferred(x/5)
54+
@inferred(5\x)
55+
@inferred(x+x)
56+
@inferred(x-x)
3257

58+
# broadcasting
3359
_scalar_op(y) = y + 1
3460
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
3561
_broadcast_wrapper(y) = _scalar_op.(y)

0 commit comments

Comments
 (0)