Skip to content

Commit cd8ca95

Browse files
authored
Merge pull request #627 from AkhilAkkapelli/darray-slice
Add `view` Support for Efficient DArray Chunk Slicing
2 parents c745569 + ce99e3b commit cd8ca95

File tree

4 files changed

+222
-4
lines changed

4 files changed

+222
-4
lines changed

docs/src/datadeps.md

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,81 @@ function Dagger.move!(dep_mod::Any, from_space::Dagger.MemorySpace, to_space::Da
220220
return
221221
end
222222
```
223+
224+
## Chunk and DTask slicing with `view`
225+
226+
The `view` function allows you to efficiently create a "view" of a `Chunk` or `DTask` that contains an array. This enables operations on specific parts of your distributed data using standard Julia array slicing, without needing to materialize the entire array.
227+
228+
```julia
229+
view(c::Chunk, slices...) -> ChunkView
230+
view(c::DTask, slices...) -> ChunkView
231+
```
232+
233+
These methods create a `ChunkView` of a `Chunk` or `DTask`, which may be used as an argument to a `Dagger.@spawn` call in a Datadeps region. You specify the desired view using standard Julia array slicing syntax, identical to how you would slice a regular array.
234+
235+
#### Examples
236+
237+
```julia
238+
julia> A = rand(64, 64)
239+
64×64 Matrix{Float64}:
240+
[...]
241+
242+
julia> DA = DArray(A, Blocks(8,8))
243+
64x64 DMatrix{Float64} with 8x8 partitions of size 8x8:
244+
[...]
245+
246+
julia> chunk = DA.chunks[1,1]
247+
DTask (finished)
248+
249+
julia> view(chunk, :, :) # View the entire 8x8 chunk
250+
ChunkView{2}(Dagger.Chunk(...), (Colon(), Colon()))
251+
252+
julia> view(chunk, 1:4, 1:4) # View the top-left 4x4 sub-region of the chunk
253+
ChunkView{2}(Dagger.Chunk(...), (1:4, 1:4))
254+
255+
julia> view(chunk, 1, :) # View the first row of the chunk
256+
ChunkView{2}(Dagger.Chunk(...), (1, Colon()))
257+
258+
julia> view(chunk, :, 5) # View the fifth column of the chunk
259+
ChunkView{2}(Dagger.Chunk(...), (Colon(), 5))
260+
261+
julia> view(chunk, 1:2:7, 2:2:8) # View with stepped ranges
262+
ChunkView{2}(Dagger.Chunk(...), (1:2:7, 2:2:8))
263+
```
264+
265+
#### Example Usage: Parallel Row Summation of a DArray using `view`
266+
267+
This example demonstrates how to compute the sum of every row of a `DArray`: `view` selects individual rows within each chunk, and each row's sum is written into a vector of row sums.
268+
269+
```julia
270+
julia> A = DArray(rand(10, 1000), Blocks(2, 1000))
271+
10x1000 DMatrix{Float64} with 5x1 partitions of size 2x1000:
272+
[...]
273+
274+
# Helper function to sum a single row and store it in a provided array view
275+
julia> @everywhere function sum_array_row!(row_sum::AbstractArray{Float64}, x::AbstractArray{Float64})
276+
row_sum[1] = sum(x)
277+
end
278+
279+
# Number of rows
280+
julia> nrows = size(A,1)
281+
282+
# Initialize a zero array to hold the final row sums
283+
julia> row_sums = zeros(nrows)
284+
285+
# Spawn tasks to sum each row in parallel using views
286+
julia> Dagger.spawn_datadeps() do
287+
sz = size(A.chunks,1)
288+
nrows_per_chunk = nrows ÷ sz
289+
for i in 1:sz
290+
for j in 1:nrows_per_chunk
291+
Dagger.@spawn sum_array_row!(Out(view(row_sums, (nrows_per_chunk*(i-1)+j):(nrows_per_chunk*(i-1)+j))),
292+
In(Dagger.view(A.chunks[i,1], j:j, :)))
293+
end
294+
end
295+
end
296+
297+
# Print the result
298+
julia> println("Row sums: ", row_sums)
299+
Row sums: [499.8765, 500.1234, ..., 499.9876]
300+
```

src/datadeps.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,9 @@ end
450450

451451
# Make a copy of each piece of data on each worker
452452
# memory_space => {arg => copy_of_arg}
453+
# `isremotehandle(x)` is `true` for `DTask` and `Chunk` arguments and `false` for
# every other argument.
isremotehandle(::Any) = false
isremotehandle(::DTask) = true
isremotehandle(::Chunk) = true
453456
function generate_slot!(state::DataDepsState, dest_space, data)
454457
if data isa DTask
455458
data = fetch(data; raw=true)
@@ -458,22 +461,25 @@ function generate_slot!(state::DataDepsState, dest_space, data)
458461
to_proc = first(processors(dest_space))
459462
from_proc = first(processors(orig_space))
460463
dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space)
461-
if orig_space == dest_space
464+
if orig_space == dest_space && (data isa Chunk || !isremotehandle(data))
465+
# Fast path for local data or data already in a Chunk
462466
data_chunk = tochunk(data, from_proc)
463467
dest_space_args[data] = data_chunk
464468
@assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc
465469
@assert memory_space(data_chunk) == orig_space
466470
else
467-
w = only(unique(map(get_parent, collect(processors(dest_space))))).pid
471+
to_w = root_worker_id(dest_space)
468472
ctx = Sch.eager_context()
469473
id = rand(Int)
470474
timespan_start(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data))
471-
dest_space_args[data] = remotecall_fetch(w, from_proc, to_proc, data) do from_proc, to_proc, data
475+
dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data
472476
data_converted = move(from_proc, to_proc, data)
473477
data_chunk = tochunk(data_converted, to_proc)
474478
@assert processor(data_chunk) in processors(dest_space)
475479
@assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
476-
@assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
480+
if orig_space != dest_space
481+
@assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
482+
end
477483
return data_chunk
478484
end
479485
timespan_finish(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data=dest_space_args[data]))

src/memory-spaces.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `
122122
memory_spans(x) = memory_spans(aliasing(x))
123123
memory_spans(x, T) = memory_spans(aliasing(x, T))
124124

125+
125126
struct AliasingWrapper <: AbstractAliasing
126127
inner::AbstractAliasing
127128
hash::UInt64
@@ -387,3 +388,71 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan)
387388
y_end = y_span.ptr + y_span.len - 1
388389
return x_span.ptr <= y_end && y_span.ptr <= x_end
389390
end
391+
392+
"""
    ChunkView{N}

Pairs a `Chunk` with a tuple of `N` slices. Each slice is an `Int`, an
`AbstractRange{Int}`, or a `Colon`.
"""
struct ChunkView{N}
    chunk::Chunk
    slices::NTuple{N,Union{Int,AbstractRange{Int},Colon}}
end
396+
397+
"""
    view(c::Chunk, slices...) -> ChunkView

Wrap `c` together with `slices` in a `ChunkView`.

When `c.domain` is an `ArrayDomain`, the slices are validated before the
`ChunkView` is constructed: there must be exactly one slice per dimension of the
domain, every `Int` slice and every non-empty range slice must lie within
`1:size(c.domain)[i]` for its dimension `i`, and each slice must be an `Int`, an
`AbstractRange`, or `Colon()`. Empty ranges and `Colon()` are accepted without a
bounds check. No validation is performed for other kinds of domain.

# Throws
- `DimensionMismatch`: the number of slices differs from `ndims(c.domain)`.
- `ArgumentError`: an index or range is out of bounds, or a slice has an
  unsupported type.
"""
function Base.view(c::Chunk, slices...)
    if c.domain isa ArrayDomain
        nd, sz = ndims(c.domain), size(c.domain)
        nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))"))

        for (i, s) in enumerate(slices)
            if s isa Int
                1 <= s <= sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))"))
            elseif s isa AbstractRange
                # An empty range selects nothing, so there is nothing to bounds-check
                isempty(s) && continue
                1 <= first(s) <= last(s) <= sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))"))
            elseif s === Colon()
                continue
            else
                throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon"))
            end
        end
    end

    return ChunkView(c, slices)
end
418+
419+
"""
    view(c::DTask, slices...)

Fetch `c` with `raw=true` and forward the result, together with `slices`, to
`view`.
"""
function Base.view(c::DTask, slices...)
    raw = fetch(c; raw=true)
    return view(raw, slices...)
end
420+
421+
# Aliasing information for a `ChunkView`. On the worker given by
# `root_worker_id(x.chunk.processor)`, the chunk is unwrapped, sliced with `view`
# using `x.slices`, and `aliasing` of that view is returned to the caller.
function aliasing(x::ChunkView{N}) where N
    owner = root_worker_id(x.chunk.processor)
    return remotecall_fetch(owner, x.chunk, x.slices) do chunk, slices
        data = unwrap(chunk)
        return aliasing(view(data, slices...))
    end
end
428+
# A `ChunkView` reports the memory space of the chunk it wraps.
memory_space(cv::ChunkView) = memory_space(cv.chunk)
# `isremotehandle` is `true` for every `ChunkView`.
isremotehandle(::ChunkView) = true
430+
431+
#=
432+
function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView)
433+
to_w = root_worker_id(to_space)
434+
@assert to_w == myid()
435+
to_raw = unwrap(to.chunk)
436+
from_w = root_worker_id(from_space)
437+
from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk)
438+
from_view = view(from_raw, from.slices...)
439+
to_view = view(to_raw, to.slices...)
440+
move!(dep_mod, to_space, from_space, to_view, from_view)
441+
return
442+
end
443+
=#
444+
445+
# Move a `ChunkView` from `from_proc` to `to_proc`.
#
# When both processors are equal, a `view` of the unwrapped chunk is returned.
# Otherwise the sliced region is copied on the worker given by
# `root_worker_id(from_proc)`, and that copy is passed on to `move`.
function move(from_proc::Processor, to_proc::Processor, slice::ChunkView)
    from_proc == to_proc && return view(unwrap(slice.chunk), slice.slices...)

    # Need to copy the underlying data, so collapse the view
    src_worker = root_worker_id(from_proc)
    collapsed = remotecall_fetch(src_worker, slice.chunk, slice.slices) do chunk, slices
        return copy(view(unwrap(chunk), slices...))
    end
    return move(from_proc, to_proc, collapsed)
end
457+
458+
"""
    fetch(slice::ChunkView)

Fetch the wrapped chunk and return a `view` of the result using `slice.slices`.
"""
function Base.fetch(slice::ChunkView)
    data = fetch(slice.chunk)
    return view(data, slice.slices...)
end

test/datadeps.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import Dagger: ChunkView, Chunk
12
using LinearAlgebra, Graphs
23

34
@testset "Memory Aliasing" begin
@@ -17,6 +18,70 @@ using LinearAlgebra, Graphs
1718
@test s.sz == sizeof(3)
1819
end
1920

21+
# Exercises `view` on both a `Chunk` and a `DTask`: slice validation, the
# `fetch`/`memory_space`/`aliasing` methods of the resulting `ChunkView`, and
# in-place mutation through a `ChunkView` inside a Datadeps region.
@testset "ChunkView" begin
    DA = rand(Blocks(8, 8), 64, 64)
    task1 = DA.chunks[1,1]::DTask
    chunk1 = fetch(task1; raw=true)::Chunk
    v1 = view(chunk1, :, :)
    task2 = DA.chunks[1,2]::DTask
    chunk2 = fetch(task2; raw=true)::Chunk
    v2 = view(chunk2, :, :)
    @test v2 isa ChunkView

    # Each check must hold whether the view is taken from the Chunk or the DTask
    for obj in (chunk1, task1)
        @testset "Valid Slices" begin
            @test view(obj, :, :) isa ChunkView && view(obj, 1:8, 1:8) isa ChunkView
            @test view(obj, 1:2:7, :) isa ChunkView && view(obj, :, 2:2:8) isa ChunkView
            @test view(obj, 1, :) isa ChunkView && view(obj, :, 1) isa ChunkView
            @test view(obj, 3:3, 5:5) isa ChunkView && view(obj, 5:7, 1:2:4) isa ChunkView
            @test view(obj, 8, 8) isa ChunkView
            @test view(obj, 1:0, :) isa ChunkView
        end

        @testset "Dimension Mismatch" begin
            @test_throws DimensionMismatch view(obj, :)
            @test_throws DimensionMismatch view(obj, :, :, :)
        end

        @testset "Int Slice Out of Bounds" begin
            @test_throws ArgumentError view(obj, 0, :)
            @test_throws ArgumentError view(obj, :, 9)
            @test_throws ArgumentError view(obj, 9, 1)
        end

        @testset "Range Slice Out of Bounds" begin
            @test_throws ArgumentError view(obj, 0:5, :)
            @test_throws ArgumentError view(obj, 1:8, 5:10)
            @test_throws ArgumentError view(obj, 2:2:10, :)
        end

        @testset "Invalid Slice Types" begin
            # A single tuple argument counts as one slice, hence the dimension error
            @test_throws DimensionMismatch view(obj, (1:2, :))
            @test_throws ArgumentError view(obj, :, [1, 2])
        end
    end

    @test fetch(v1) == fetch(chunk1)

    @test Dagger.memory_space(v1) == Dagger.memory_space(chunk1)
    @test Dagger.aliasing(v1) isa Dagger.StridedAliasing
    ptr = remotecall_fetch(chunk1.handle.owner, chunk1) do chunk
        UInt(pointer(Dagger.unwrap(chunk)))
    end
    @test Dagger.aliasing(v1).base_ptr.addr == ptr

    @testset "Aliasing" begin
        f! = v1 -> begin
            v1 .= 0
            return
        end
        Dagger.spawn_datadeps() do
            Dagger.@spawn f!(InOut(v1))
        end
        # Writing through the view must be visible in the backing DArray
        @test collect(DA)[1:8, 1:8] == zeros(8, 8)
    end
end
84+
2085
function with_logs(f)
2186
Dagger.enable_logging!(;taskdeps=true, taskargs=true)
2287
try

0 commit comments

Comments
 (0)