Skip to content

Commit b57c6b1

Browse files
Merge pull request #462 from cwittens/master
Improved performance of Broadcasting.
2 parents cb67978 + 71942b9 commit b57c6b1

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

src/array_partition.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,20 @@ end
365365
end
366366

367367
@inline function Base.copyto!(dest::ArrayPartition,
368-
bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {
369-
Style,
370-
}
368+
bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {Style}
371369
N = npartitions(dest, bc)
372-
@inline function f(i)
373-
copyto!(dest.x[i], unpack(bc, i))
370+
# If every partition of dest has the same type as the first one, use a plain for-loop
371+
if all(x isa typeof(first(dest.x)) for x in dest.x)
372+
@inbounds for i in 1:N
373+
copyto!(dest.x[i], unpack(bc, i))
374+
end
375+
else
376+
# Otherwise (partitions of differing types), fall back to the original ntuple-based implementation
377+
@inline function f(i)
378+
copyto!(dest.x[i], unpack(bc, i))
379+
end
380+
ntuple(f, Val(N))
374381
end
375-
ntuple(f, Val(N))
376382
dest
377383
end
378384

@@ -411,8 +417,8 @@ end
411417
i) where {Style <: Broadcast.DefaultArrayStyle}
412418
Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
413419
end
414-
unpack(x, ::Any) = x
415-
unpack(x::ArrayPartition, i) = x.x[i]
420+
@inline unpack(x, ::Any) = x
421+
@inline unpack(x::ArrayPartition, i) = x.x[i]
416422

417423
@inline function unpack_args(i, args::Tuple)
418424
(unpack(args[1], i), unpack_args(i, Base.tail(args))...)

src/named_array_partition.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,9 @@ end
138138
@inline function Base.copyto!(dest::NamedArrayPartition,
139139
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
140140
N = npartitions(dest, bc)
141-
@inline function f(i)
142-
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
141+
@inbounds for i in 1:N
142+
copyto!(dest.x[i], unpack(bc, i))
143143
end
144-
ntuple(f, Val(N))
145144
return dest
146145
end
147146

0 commit comments

Comments
 (0)