78 changes: 78 additions & 0 deletions docs/src/datadeps.md
@@ -220,3 +220,81 @@ function Dagger.move!(dep_mod::Any, from_space::Dagger.MemorySpace, to_space::Da
    return
end
```

## Chunk and DTask slicing with `view`

The `view` function allows you to efficiently create a "view" of a `Chunk` or `DTask` that contains an array. This enables operations on specific parts of your distributed data using standard Julia array slicing, without needing to materialize the entire array.

```julia
view(c::Chunk, slices...) -> ChunkView
view(c::DTask, slices...) -> ChunkView
```

These methods create a `ChunkView` of a `Chunk` or `DTask`, which may be used as an argument to a `Dagger.@spawn` call in a Datadeps region. You specify the desired view using standard Julia array slicing syntax, identical to how you would slice a regular array.

#### Examples

```julia
julia> A = rand(64, 64)
64×64 Matrix{Float64}:
[...]

julia> DA = DArray(A, Blocks(8,8))
64×64 DMatrix{Float64} with 8×8 partitions of size 8×8:
[...]

julia> chunk = DA.chunks[1,1]
DTask (finished)

julia> view(chunk, :, :) # View the entire 8x8 chunk
ChunkView{2}(Dagger.Chunk(...), (Colon(), Colon()))

julia> view(chunk, 1:4, 1:4) # View the top-left 4x4 sub-region of the chunk
ChunkView{2}(Dagger.Chunk(...), (1:4, 1:4))

julia> view(chunk, 1, :) # View the first row of the chunk
ChunkView{2}(Dagger.Chunk(...), (1, Colon()))

julia> view(chunk, :, 5) # View the fifth column of the chunk
ChunkView{2}(Dagger.Chunk(...), (Colon(), 5))

julia> view(chunk, 1:2:7, 2:2:8) # View with stepped ranges
ChunkView{2}(Dagger.Chunk(...), (1:2:7, 2:2:8))
```
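
Calling `fetch` on a `ChunkView` materializes it by fetching the underlying chunk and applying the stored slices, i.e. it returns `view(fetch(chunk), slices...)`. A minimal sketch (the printed output below is illustrative):

```julia
julia> v = view(chunk, 1:4, 1:4);

julia> fetch(v) # equivalent to view(fetch(chunk), 1:4, 1:4)
4×4 view(::Matrix{Float64}, 1:4, 1:4) with eltype Float64:
[...]
```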

#### Example Usage: Parallel Row Summation of a DArray using `view`

This example demonstrates how to compute per-row sums of a `DArray` in parallel, using `view` to address individual rows within each chunk and collect the results into a vector of row sums.

```julia
julia> A = DArray(rand(10, 1000), Blocks(2, 1000))
10×1000 DMatrix{Float64} with 5×1 partitions of size 2×1000:
[...]

# Helper function to sum a single row and store it in a provided array view
julia> @everywhere function sum_array_row!(row_sum::AbstractArray{Float64}, x::AbstractArray{Float64})
           row_sum[1] = sum(x)
       end

# Number of rows
julia> nrows = size(A,1)

# Initialize a zero array to hold the final row sums
julia> row_sums = zeros(nrows)

# Spawn tasks to sum each row in parallel using views
julia> Dagger.spawn_datadeps() do
           sz = size(A.chunks, 1)
           nrows_per_chunk = nrows ÷ sz
           for i in 1:sz
               for j in 1:nrows_per_chunk
                   row = nrows_per_chunk*(i-1) + j
                   Dagger.@spawn sum_array_row!(Out(view(row_sums, row:row)),
                                                In(Dagger.view(A.chunks[i,1], j:j, :)))
               end
           end
       end

# Print the result
julia> println("Row sums: ", row_sums)
Row sums: [499.8765, 500.1234, ..., 499.9876]
```
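
#### Example Usage: In-place Mutation through a `ChunkView`

Because a `ChunkView` aliases the chunk's memory rather than copying it, it can also be passed as an `InOut` argument to mutate part of a `DArray` in place. Below is a minimal sketch of this pattern; `zero_block!` is an illustrative helper, not part of Dagger's API:

```julia
julia> DA = rand(Blocks(8, 8), 64, 64);

julia> chunk = fetch(DA.chunks[1,1]; raw=true);

julia> @everywhere zero_block!(x) = (x .= 0; nothing)

julia> Dagger.spawn_datadeps() do
           Dagger.@spawn zero_block!(InOut(view(chunk, :, :)))
       end

julia> collect(DA)[1:8, 1:8] == zeros(8, 8)
true
```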
14 changes: 10 additions & 4 deletions src/datadeps.jl
@@ -450,6 +450,9 @@ end

# Make a copy of each piece of data on each worker
# memory_space => {arg => copy_of_arg}
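+# Whether `x` is a handle (e.g. a DTask or Chunk) that may reference data on another worker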
+isremotehandle(x) = false
+isremotehandle(x::DTask) = true
+isremotehandle(x::Chunk) = true
function generate_slot!(state::DataDepsState, dest_space, data)
    if data isa DTask
        data = fetch(data; raw=true)
@@ -458,22 +461,25 @@ function generate_slot!(state::DataDepsState, dest_space, data)
    to_proc = first(processors(dest_space))
    from_proc = first(processors(orig_space))
    dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space)
-    if orig_space == dest_space
+    if orig_space == dest_space && (data isa Chunk || !isremotehandle(data))
+        # Fast path for local data or data already in a Chunk
        data_chunk = tochunk(data, from_proc)
        dest_space_args[data] = data_chunk
        @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc
        @assert memory_space(data_chunk) == orig_space
    else
-        w = only(unique(map(get_parent, collect(processors(dest_space))))).pid
+        to_w = root_worker_id(dest_space)
        ctx = Sch.eager_context()
        id = rand(Int)
        timespan_start(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data))
-        dest_space_args[data] = remotecall_fetch(w, from_proc, to_proc, data) do from_proc, to_proc, data
+        dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data
            data_converted = move(from_proc, to_proc, data)
            data_chunk = tochunk(data_converted, to_proc)
            @assert processor(data_chunk) in processors(dest_space)
            @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
-            @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
+            if orig_space != dest_space
+                @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
+            end
            return data_chunk
        end
        timespan_finish(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data=dest_space_args[data]))
69 changes: 69 additions & 0 deletions src/memory-spaces.jl
@@ -122,6 +122,7 @@ memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `
memory_spans(x) = memory_spans(aliasing(x))
memory_spans(x, T) = memory_spans(aliasing(x, T))


struct AliasingWrapper <: AbstractAliasing
inner::AbstractAliasing
hash::UInt64
@@ -387,3 +388,71 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan)
    y_end = y_span.ptr + y_span.len - 1
    return x_span.ptr <= y_end && y_span.ptr <= x_end
end

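# Lazy view into the array held by `chunk`, with one slice entry per dimension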
struct ChunkView{N}
    chunk::Chunk
    slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}}
end

function Base.view(c::Chunk, slices...)
    if c.domain isa ArrayDomain
        nd, sz = ndims(c.domain), size(c.domain)
        nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))"))

        for (i, s) in enumerate(slices)
            if s isa Int
                1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))"))
            elseif s isa AbstractRange
                isempty(s) && continue
                1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))"))
            elseif s === Colon()
                continue
            else
                throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i; expected Int, AbstractRange, or Colon"))
            end
        end
    end

    return ChunkView(c, slices)
end

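# Viewing a DTask views its underlying Chunk (fetched raw, without copying the data)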
Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...)

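# Aliasing info for the viewed region is computed on the worker that owns the chunk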
function aliasing(x::ChunkView{N}) where N
    remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices
        x = unwrap(x)
        v = view(x, slices...)
        return aliasing(v)
    end
end
memory_space(x::ChunkView) = memory_space(x.chunk)
isremotehandle(x::ChunkView) = true

#=
function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView)
    to_w = root_worker_id(to_space)
    @assert to_w == myid()
    to_raw = unwrap(to.chunk)
    from_w = root_worker_id(from_space)
    from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk)
    from_view = view(from_raw, from.slices...)
    to_view = view(to_raw, to.slices...)
    move!(dep_mod, to_space, from_space, to_view, from_view)
    return
end
=#

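# Materialize a ChunkView on `to_proc`; when crossing workers, only the sliced region is copied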
function move(from_proc::Processor, to_proc::Processor, slice::ChunkView)
    if from_proc == to_proc
        return view(unwrap(slice.chunk), slice.slices...)
    else
        # Need to copy the underlying data, so collapse the view
        from_w = root_worker_id(from_proc)
        data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices
            copy(view(unwrap(chunk), slices...))
        end
        return move(from_proc, to_proc, data)
    end
end

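# Materialize the view locally: fetch the chunk's data, then apply the stored slices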
Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...)
65 changes: 65 additions & 0 deletions test/datadeps.jl
@@ -1,3 +1,4 @@
import Dagger: ChunkView, Chunk
using LinearAlgebra, Graphs

@testset "Memory Aliasing" begin
@@ -17,6 +18,70 @@ using LinearAlgebra, Graphs
    @test s.sz == sizeof(3)
end

@testset "ChunkView" begin
DA = rand(Blocks(8, 8), 64, 64)
task1 = DA.chunks[1,1]::DTask
chunk1 = fetch(task1; raw=true)::Chunk
v1 = view(chunk1, :, :)
task2 = DA.chunks[1,2]::DTask
chunk2 = fetch(task2; raw=true)::Chunk
v2 = view(chunk2, :, :)

for obj in (chunk1, task1)
@testset "Valid Slices" begin
@test view(obj, :, :) isa ChunkView && view(obj, 1:8, 1:8) isa ChunkView
@test view(obj, 1:2:7, :) isa ChunkView && view(obj, :, 2:2:8) isa ChunkView
@test view(obj, 1, :) isa ChunkView && view(obj, :, 1) isa ChunkView
@test view(obj, 3:3, 5:5) isa ChunkView && view(obj, 5:7, 1:2:4) isa ChunkView
@test view(obj, 8, 8) isa ChunkView
@test view(obj, 1:0, :) isa ChunkView
end

@testset "Dimension Mismatch" begin
@test_throws DimensionMismatch view(obj, :)
@test_throws DimensionMismatch view(obj, :, :, :)
end

@testset "Int Slice Out of Bounds" begin
@test_throws ArgumentError view(obj, 0, :)
@test_throws ArgumentError view(obj, :, 9)
@test_throws ArgumentError view(obj, 9, 1)
end

@testset "Range Slice Out of Bounds" begin
@test_throws ArgumentError view(obj, 0:5, :)
@test_throws ArgumentError view(obj, 1:8, 5:10)
@test_throws ArgumentError view(obj, 2:2:10, :)
end

@testset "Invalid Slice Types" begin
@test_throws DimensionMismatch view(obj, (1:2, :))
@test_throws ArgumentError view(obj, :, [1, 2])
end
end

@test fetch(v1) == fetch(chunk1)

@test Dagger.memory_space(v1) == Dagger.memory_space(chunk1)
@test Dagger.aliasing(v1) isa Dagger.StridedAliasing
ptr = remotecall_fetch(chunk1.handle.owner, chunk1) do chunk
UInt(pointer(Dagger.unwrap(chunk)))
end
@test Dagger.aliasing(v1).base_ptr.addr == ptr

@testset "Aliasing" begin
f! = v1 -> begin
@show typeof(v1) v1
v1 .= 0
return
end
Dagger.spawn_datadeps() do
Dagger.@spawn f!(InOut(v1))
end
@test collect(DA)[1:8, 1:8] == zeros(8, 8)
end
end

function with_logs(f)
    Dagger.enable_logging!(;taskdeps=true, taskargs=true)
    try