Skip to content

Commit 0b90d37

Browse files
committed
fix to copyto!
1 parent 8bec0d7 commit 0b90d37

File tree

3 files changed

+55
-6
lines changed

3 files changed

+55
-6
lines changed

src/array_partition.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,25 @@ function Base.copyto!(dest::Array,A::ArrayPartition)
107107
@assert length(dest) == length(A)
108108
cur = 1
109109
@inbounds for i in 1:length(A.x)
110-
dest[cur:(cur+length(A.x[i])-1)] .= A.x[i]
111-
cur += length(A.x[i])
110+
dest[cur:(cur+length(A.x[i])-1)] .= vec(A.x[i])
111+
cur += length(A.x[i])
112112
end
113113
dest
114114
end
115115

116116
function Base.copyto!(A::ArrayPartition,src::ArrayPartition)
117117
@assert length(src) == length(A)
118-
cur = 1
119-
@inbounds for i in 1:length(A.x)
120-
A.x[i] .= @view(src[cur:(cur+length(A.x[i])-1)])
121-
cur += length(A.x[i])
118+
if size.(A.x) == size.(src.x)
119+
A .= src
120+
else
121+
cnt = 0
122+
for i in eachindex(A.x)
123+
x = A.x[i]
124+
for k in eachindex(x)
125+
cnt += 1
126+
x[k] = src[cnt]
127+
end
128+
end
122129
end
123130
A
124131
end

test.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using RecursiveArrayTools, BenchmarkTools
2+
3+
function Base.copyto!(A::ArrayPartition,src::ArrayPartition)
4+
@assert length(src) == length(A)
5+
if size.(A.x) == size.(src.x)
6+
A .= src
7+
else
8+
cnt = 0
9+
for i in eachindex(A.x)
10+
x = A.x[i]
11+
for k in eachindex(x)
12+
cnt += 1
13+
x[k] = src[cnt]
14+
end
15+
end
16+
end
17+
A
18+
end
19+
20+
x = ArrayPartition(randn(1000,10),randn(1000),randn(3));
21+
y =zero(x);
22+
@btime copyto!($y,$x);
23+
@btime $y .= $x;

test/partitions_test.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,22 @@ _scalar_op(y) = y + 1
8585
_broadcast_wrapper(y) = _scalar_op.(y)
8686
# Issue #8
8787
# @inferred _broadcast_wrapper(x)
88+
89+
#### testing copyto!
90+
S = [
91+
((1,),(2,)) => ((1,),(2,)),
92+
((3,2),(2,)) => ((3,2),(2,)),
93+
((3,2),(2,)) => ((3,),(3,),(2,))
94+
]
95+
96+
for sizes in S
97+
x = ArrayPartition( randn.(sizes[1]) )
98+
y = ArrayPartition( zeros.(sizes[2]) )
99+
y_array = zeros(length(x))
100+
copyto!(y,x) #testing Base.copyto!(dest::ArrayPartition,A::ArrayPartition)
101+
copyto!(y_array,x) #testing Base.copyto!(dest::Array,A::ArrayPartition)
102+
@test all([x[i] == y[i] for i in eachindex(x)])
103+
@test all([x[i] == y_array[i] for i in eachindex(x)])
104+
end
105+
106+

0 commit comments

Comments
 (0)