Skip to content

Commit 969155b

Browse files
committed
DArray: Add copyto! and copy buffering helpers
1 parent 6c17cce commit 969155b

File tree

3 files changed

+113
-0
lines changed

3 files changed

+113
-0
lines changed

docs/src/darray.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ This list is not exhaustive, but documents operations which are known to work we
217217

218218
From `Base`:
219219
- Broadcasting
220+
- `similar`/`copy`/`copyto!`
220221
- `map`/`reduce`/`mapreduce`
221222
- `sum`/`prod`
222223
- `minimum`/`maximum`/`extrema`

src/Dagger.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ include("datadeps.jl")
6060
include("array/darray.jl")
6161
include("array/alloc.jl")
6262
include("array/map-reduce.jl")
63+
include("array/copy.jl")
6364

6465
# File IO
6566
include("file-io.jl")

src/array/copy.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copy Buffering
2+
3+
function maybe_copy_buffered(f, args...)
4+
@assert all(arg->arg isa Pair{<:DArray,<:Blocks}, args) "maybe_copy_buffered only supports `DArray`=>`Blocks`"
5+
if any(arg_part->arg_part[1].partitioning != arg_part[2], args)
6+
return copy_buffered(f, args...)
7+
else
8+
return f(map(first, args)...)
9+
end
10+
end
11+
function copy_buffered(f, args...)
12+
real_args = map(arg_part->arg_part[1], args)
13+
buffered_args = map(arg_part->allocate_copy_buffer(arg_part[2], arg_part[1]), args)
14+
for (buf_arg, arg) in zip(buffered_args, real_args)
15+
copyto!(buf_arg, arg)
16+
end
17+
result = f(buffered_args...)
18+
for (buf_arg, arg) in zip(buffered_args, real_args)
19+
copyto!(arg, buf_arg)
20+
end
21+
return result
22+
end
23+
function allocate_copy_buffer(part::Blocks{N}, A::DArray{T,N}) where {T,N}
24+
# FIXME: undef initializer
25+
return zeros(part, T, size(A))
26+
end
27+
function Base.copyto!(B::DArray{T,N}, A::DArray{T,N}) where {T,N}
28+
if size(B) != size(A)
29+
throw(DimensionMismatch("Cannot copy from array of size $(size(A)) to array of size $(size(B))"))
30+
end
31+
32+
Bc = B.chunks
33+
Ac = A.chunks
34+
Asd_all = A.subdomains::DomainBlocks{N}
35+
36+
Dagger.spawn_datadeps() do
37+
for Bidx in CartesianIndices(Bc)
38+
Bpart = Bc[Bidx]
39+
Bsd = B.subdomains[Bidx]
40+
41+
# Find the first overlapping subdomain of A
42+
if A.partitioning isa Blocks
43+
Aidx = CartesianIndex(ntuple(i->fld1(Bsd.indexes[i].start, A.partitioning.blocksize[i]), N))
44+
else
45+
# Fallback just in case of non-dense partitioning
46+
Aidx = first(CartesianIndices(Ac))
47+
Asd = first(Asd_all)
48+
for dim in 1:N
49+
while Asd.indexes[dim].stop < Bsd.indexes[dim].start
50+
Aidx += CartesianIndex(ntuple(i->i==dim, N))
51+
Asd = Asd_all[Aidx]
52+
end
53+
end
54+
end
55+
Aidx_start = Aidx
56+
57+
# Find the last overlapping subdomain of A
58+
for dim in 1:N
59+
while true
60+
Aidx_next = Aidx + CartesianIndex(ntuple(i->i==dim, N))
61+
if !(Aidx_next in CartesianIndices(Ac))
62+
break
63+
end
64+
Asd_next = Asd_all[Aidx_next]
65+
if Asd_next.indexes[dim].start <= Bsd.indexes[dim].stop
66+
Aidx = Aidx_next
67+
else
68+
break
69+
end
70+
end
71+
end
72+
Aidx_end = Aidx
73+
74+
# Find the span and set of subdomains of A overlapping Bpart
75+
Aidx_span = Aidx_start:Aidx_end
76+
Asd_view = view(A.subdomains, Aidx_span)
77+
78+
# Copy all overlapping subdomains of A
79+
for Aidx in Aidx_span
80+
Asd = Asd_all[Aidx]
81+
Apart = Ac[Aidx]
82+
83+
# Compute the true range
84+
range_start = CartesianIndex(ntuple(i->max(Bsd.indexes[i].start, Asd.indexes[i].start), N))
85+
range_end = CartesianIndex(ntuple(i->min(Bsd.indexes[i].stop, Asd.indexes[i].stop), N))
86+
range_diff = range_end - range_start
87+
88+
# Compute the offset range into Apart
89+
Asd_start = ntuple(i->Asd.indexes[i].start, N)
90+
Asd_end = ntuple(i->Asd.indexes[i].stop, N)
91+
Arange = range(range_start - CartesianIndex(Asd_start) + CartesianIndex{N}(1),
92+
range_start - CartesianIndex(Asd_start) + CartesianIndex{N}(1) + range_diff)
93+
94+
# Compute the offset range into Bpart
95+
Bsd_start = ntuple(i->Bsd.indexes[i].start, N)
96+
Bsd_end = ntuple(i->Bsd.indexes[i].stop, N)
97+
Brange = range(range_start - CartesianIndex(Bsd_start) + CartesianIndex{N}(1),
98+
range_start - CartesianIndex(Bsd_start) + CartesianIndex{N}(1) + range_diff)
99+
100+
# Perform view copy
101+
Dagger.@spawn copyto_view!(Out(Bpart), Brange, In(Apart), Arange)
102+
end
103+
end
104+
end
105+
106+
return B
107+
end
108+
function copyto_view!(Bpart, Brange, Apart, Arange)
109+
copyto!(view(Bpart, Brange), view(Apart, Arange))
110+
return
111+
end

0 commit comments

Comments
 (0)