Skip to content

Commit b08a940

Browse files
committed
new copyto! function
1 parent f0f51c3 commit b08a940

File tree

2 files changed

+5
-22
lines changed

2 files changed

+5
-22
lines changed

src/array_partition.jl

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -364,28 +364,13 @@ end
364364
ArrayPartition(f, N)
365365
end
366366

367-
# old version
368-
# @inline function Base.copyto!(dest::ArrayPartition,
369-
# bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {
370-
# Style,
371-
# }
372-
# N = npartitions(dest, bc)
373-
# @inline function f(i)
374-
# copyto!(dest.x[i], unpack(bc, i))
375-
# end
376-
# ntuple(f, Val(N))
377-
# dest
378-
# end
379-
380-
# new version
381367
@inline function Base.copyto!(dest::ArrayPartition,
382368
bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {Style}
383369
N = npartitions(dest, bc)
384-
# Check if this is a simple enough broadcast that we can optimize
385-
if bc.f isa Union{typeof(+), typeof(*), typeof(muladd)}
370+
# If dest is all the same underlying array type, use for-loop
371+
if all(x isa typeof(first(dest.x)) for x in dest.x)
386372
@inbounds for i in 1:N
387-
# Use materialize! which is more efficient than copyto! for simple broadcasts
388-
Base.Broadcast.materialize!(dest.x[i], unpack(bc, i))
373+
copyto!(dest.x[i], unpack(bc, i))
389374
end
390375
else
391376
# Fall back to original implementation for complex broadcasts

src/named_array_partition.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,12 @@ end
135135
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
136136
end
137137

138-
# TODO: has this also performance problems and can be improved?
139138
@inline function Base.copyto!(dest::NamedArrayPartition,
140139
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
141140
N = npartitions(dest, bc)
142-
@inline function f(i)
143-
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
141+
@inbounds for i in 1:N
142+
copyto!(dest.x[i], unpack(bc, i))
144143
end
145-
ntuple(f, Val(N))
146144
return dest
147145
end
148146

0 commit comments

Comments
 (0)