Skip to content

Commit d42c552

Browse files
Merge pull request #69 from nantonel/master
Fix to copyto!(::ArrayPartition,::ArrayPartition) and copyto!(::Array,::ArrayPartition)
2 parents 8bec0d7 + c33674d commit d42c552

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/array_partition.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,29 @@ end
103103
Base.mapreduce(f,op,A::ArrayPartition) = mapreduce(f,op,(mapreduce(f,op,x) for x in A.x))
104104
Base.any(f,A::ArrayPartition) = any(f,(any(f,x) for x in A.x))
105105
Base.any(f::Function,A::ArrayPartition) = any(f,(any(f,x) for x in A.x))
106-
function Base.copyto!(dest::Array,A::ArrayPartition)
106+
function Base.copyto!(dest::AbstractArray,A::ArrayPartition)
107107
@assert length(dest) == length(A)
108108
cur = 1
109109
@inbounds for i in 1:length(A.x)
110-
dest[cur:(cur+length(A.x[i])-1)] .= A.x[i]
111-
cur += length(A.x[i])
110+
dest[cur:(cur+length(A.x[i])-1)] .= vec(A.x[i])
111+
cur += length(A.x[i])
112112
end
113113
dest
114114
end
115115

116116
function Base.copyto!(A::ArrayPartition,src::ArrayPartition)
117117
@assert length(src) == length(A)
118-
cur = 1
119-
@inbounds for i in 1:length(A.x)
120-
A.x[i] .= @view(src[cur:(cur+length(A.x[i])-1)])
121-
cur += length(A.x[i])
118+
if size.(A.x) == size.(src.x)
119+
A .= src
120+
else
121+
cnt = 0
122+
for i in eachindex(A.x)
123+
x = A.x[i]
124+
for k in eachindex(x)
125+
cnt += 1
126+
x[k] = src[cnt]
127+
end
128+
end
122129
end
123130
A
124131
end

test/partitions_test.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,22 @@ _scalar_op(y) = y + 1
8585
_broadcast_wrapper(y) = _scalar_op.(y)
8686
# Issue #8
8787
# @inferred _broadcast_wrapper(x)
88+
89+
#### testing copyto!
90+
S = [
91+
((1,),(2,)) => ((1,),(2,)),
92+
((3,2),(2,)) => ((3,2),(2,)),
93+
((3,2),(2,)) => ((3,),(3,),(2,))
94+
]
95+
96+
for sizes in S
97+
x = ArrayPartition( randn.(sizes[1]) )
98+
y = ArrayPartition( zeros.(sizes[2]) )
99+
y_array = zeros(length(x))
100+
copyto!(y,x) #testing Base.copyto!(dest::ArrayPartition,A::ArrayPartition)
101+
copyto!(y_array,x) #testing Base.copyto!(dest::Array,A::ArrayPartition)
102+
@test all([x[i] == y[i] for i in eachindex(x)])
103+
@test all([x[i] == y_array[i] for i in eachindex(x)])
104+
end
105+
106+

0 commit comments

Comments
 (0)