Skip to content

Commit 6ba0879

Browse files
copy! for arraypartition
1 parent 96a83a4 commit 6ba0879

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

src/array_partition.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,23 @@ end
124124

125125
Base.mapreduce(f,op,A::ArrayPartition) = mapreduce(f,op,(mapreduce(f,op,x) for x in A.x))
126126
Base.any(f,A::ArrayPartition) = any(f,(any(f,x) for x in A.x))
127+
function Base.copy!(dest::Array,A::ArrayPartition)
128+
@assert length(dest) == length(A)
129+
cur = 1
130+
@inbounds for i in 1:length(A.x)
131+
dest[cur:(cur+length(A.x[i])-1)] .= A.x[i]
132+
cur += length(A.x[i])
133+
end
134+
end
135+
136+
function Base.copy!(A::ArrayPartition,src::ArrayPartition)
137+
@assert length(src) == length(A)
138+
cur = 1
139+
@inbounds for i in 1:length(A.x)
140+
A.x[i] .= @view(src[cur:(cur+length(A.x[i])-1)])
141+
cur += length(A.x[i])
142+
end
143+
end
127144

128145
## indexing
129146

@@ -225,7 +242,7 @@ Base.Broadcast.promote_containertype(::Type{Array}, ::Type{ArrayPartition}) = Ar
225242
N = npartitions(as...)
226243

227244
# broadcast partitions separately
228-
expr = :(broadcast(f,
245+
expr = :(@show "here!"; broadcast(f,
229246
# index partitions
230247
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
231248
for d in 1:length(as))...)))
@@ -250,6 +267,25 @@ end
250267
end
251268
end
252269

270+
@generated function Base.broadcast!(f, ::Type{ArrayPartition}, ::Type,
271+
dest::Array, as...)
272+
# common number of partitions
273+
N = npartitions(dest, as...)
274+
275+
# broadcast partitions separately
276+
quote
277+
@show "here"
278+
@show dest
279+
for i in 1:$N
280+
broadcast!(f, dest.x[i],
281+
# index partitions
282+
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
283+
for d in 1:length(as))...))
284+
end
285+
dest
286+
end
287+
end
288+
253289
## utils
254290

255291
"""

test/partitions_test.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using RecursiveArrayTools, Base.Test
2-
32
A = (rand(5),rand(5))
43
p = ArrayPartition(A)
54
@test (p.x[1][1],p.x[2][1]) == (p[1],p[6])
@@ -27,6 +26,18 @@ p .= (*).(p,a)
2726
p .= (*).(p,p2)
2827
K = (*).(p,p2)
2928

29+
p.*rand(5)
30+
b = rand(10)
31+
c = rand(10)
32+
copy!(b,p)
33+
34+
@test b[1:5] == p.x[1]
35+
@test b[6:10] == p.x[2]
36+
37+
copy!(p,c)
38+
@test c[1:5] == p.x[1]
39+
@test c[6:10] == p.x[2]
40+
3041
## inference tests
3142

3243
x = ArrayPartition([1, 2], [3.0, 4.0])

0 commit comments

Comments
 (0)