Skip to content

Commit 0c3c9d1

Browse files
Merge pull request #163 from mateuszbaran/mbaran/array-partition-broadcasting
Speed up some use cases of `ArrayPartition`
2 parents 49031eb + 5de3bb1 commit 0c3c9d1

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

src/array_partition.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ end
77
ArrayPartition(x...) = ArrayPartition((x...,))
88

99
function ArrayPartition(x::S, ::Type{Val{copy_x}}=Val{false}) where {S<:Tuple,copy_x}
10-
T = promote_type(recursive_bottom_eltype.(x)...)
10+
T = promote_type(map(recursive_bottom_eltype,x)...)
1111
if copy_x
12-
return ArrayPartition{T,S}(copy.(x))
12+
return ArrayPartition{T,S}(map(copy,x))
1313
else
1414
return ArrayPartition{T,S}(x)
1515
end
@@ -81,31 +81,31 @@ end
8181
for op in (:+, :-)
8282
@eval begin
8383
function Base.$op(A::ArrayPartition, B::ArrayPartition)
84-
Base.broadcast($op, A, B)
84+
ArrayPartition(map((x, y)->Base.broadcast($op, x, y), A.x, B.x))
8585
end
8686

8787
function Base.$op(A::ArrayPartition, B::Number)
88-
Base.broadcast($op, A, B)
88+
ArrayPartition(map(y->Base.broadcast($op, y, B), A.x))
8989
end
9090

9191
function Base.$op(A::Number, B::ArrayPartition)
92-
Base.broadcast($op, A, B)
92+
ArrayPartition(map(y->Base.broadcast($op, A, y), B.x))
9393
end
9494
end
9595
end
9696

9797
for op in (:*, :/)
9898
@eval function Base.$op(A::ArrayPartition, B::Number)
99-
Base.broadcast($op, A, B)
99+
ArrayPartition(map(y->Base.broadcast($op, y, B), A.x))
100100
end
101101
end
102102

103103
function Base.:*(A::Number, B::ArrayPartition)
104-
Base.broadcast(*, A, B)
104+
ArrayPartition(map(y->Base.broadcast(*, A, y), B.x))
105105
end
106106

107107
function Base.:\(A::Number, B::ArrayPartition)
108-
Base.broadcast(/, B, A)
108+
ArrayPartition(map(y->Base.broadcast(/, y, A), B.x))
109109
end
110110

111111
Base.:(==)(A::ArrayPartition,B::ArrayPartition) = A.x == B.x
@@ -134,7 +134,7 @@ end
134134
function Base.copyto!(A::ArrayPartition,src::ArrayPartition)
135135
@assert length(src) == length(A)
136136
if size.(A.x) == size.(src.x)
137-
A .= src
137+
map(copyto!, A.x, src.x)
138138
else
139139
cnt = 0
140140
for i in eachindex(A.x)
@@ -281,9 +281,10 @@ end
281281

282282
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style
283283
N = npartitions(dest, bc)
284-
@inbounds for i in 1:N
284+
@inline function f(i)
285285
copyto!(dest.x[i], unpack(bc, i))
286286
end
287+
ntuple(f, Val(N))
287288
dest
288289
end
289290

test/partitions_test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ _scalar_op(y) = y + 1
101101
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
102102
_broadcast_wrapper(y) = _scalar_op.(y)
103103
# Issue #8
104-
# @inferred _broadcast_wrapper(x)
104+
@inferred _broadcast_wrapper(x)
105105

106106
# Testing map
107107
@test map(x->x^2, x) == ArrayPartition(x.x[1].^2, x.x[2].^2)

0 commit comments

Comments
 (0)