diff --git a/docs/src/datadeps.md b/docs/src/datadeps.md
index f06f36c9d..edb7b3b0e 100644
--- a/docs/src/datadeps.md
+++ b/docs/src/datadeps.md
@@ -220,3 +220,103 @@ function Dagger.move!(dep_mod::Any, from_space::Dagger.MemorySpace, to_space::Da
     return
 end
 ```
+
+## Chunk and DTask slicing with `view`
+
+The `view` function allows you to efficiently create a "view" of a `Chunk` or `DTask` that contains an array. This enables operations on specific parts of your distributed data using standard Julia array slicing, without needing to materialize the entire array.
+
+```julia
+ view(c::Chunk, slices...) -> ChunkView
+ view(c::DTask, slices...) -> ChunkView
+```
+
+These methods create a `ChunkView` of a `Chunk` or `DTask`, which may be used as an argument to a `Dagger.@spawn` call in a Datadeps region. You specify the desired region using standard Julia array slicing syntax, just as you would slice a regular array.
+
+#### Examples
+
+```julia
+julia> A = rand(64, 64)
+64×64 Matrix{Float64}:
+[...]
+
+julia> DA = DArray(A, Blocks(8,8))
+64×64 DMatrix{Float64} with 8×8 partitions of size 8×8:
+[...]
+
+julia> chunk = DA.chunks[1,1]
+DTask (finished)
+
+julia> view(chunk, :, :) # View the entire 8x8 chunk
+ChunkView{2}(Dagger.Chunk(...), (Colon(), Colon()))
+
+julia> view(chunk, 1:4, 1:4) # View the top-left 4x4 sub-region of the chunk
+ChunkView{2}(Dagger.Chunk(...), (1:4, 1:4))
+
+julia> view(chunk, 1, :) # View the first row of the chunk
+ChunkView{2}(Dagger.Chunk(...), (1, Colon()))
+
+julia> view(chunk, :, 5) # View the fifth column of the chunk
+ChunkView{2}(Dagger.Chunk(...), (Colon(), 5))
+
+julia> view(chunk, 1:2:7, 2:2:8) # View with stepped ranges
+ChunkView{2}(Dagger.Chunk(...), (1:2:7, 2:2:8))
+```
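+
+Creating a `ChunkView` does not copy or move any data; the slices are applied lazily when the view is used. In particular, calling `fetch` on a `ChunkView` fetches the parent chunk and applies the stored slices. As a minimal sketch, reusing `chunk` from above:
+
+```julia
+julia> v = view(chunk, 1:4, 1:4);
+
+julia> fetch(v) == fetch(chunk)[1:4, 1:4]
+true
+```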
+
+#### Example Usage: Parallel Row Summation of a DArray using `view`
+
+This example demonstrates how to compute the sum of each row of a `DArray` in parallel, using `view` to address individual rows within each chunk and collect the results into a vector of row sums.
+
+```julia
+julia> A = DArray(rand(10, 1000), Blocks(2, 1000))
+10×1000 DMatrix{Float64} with 5×1 partitions of size 2×1000:
+[...]
+
+# Helper function to sum a single row and store it in a provided array view
+julia> @everywhere function sum_array_row!(row_sum::AbstractArray{Float64}, x::AbstractArray{Float64})
+           row_sum[1] = sum(x)
+       end
+
+# Number of rows
+julia> nrows = size(A,1)
+
+# Initialize a zero array to hold the final row sums
+julia> row_sums = zeros(nrows)
+
+# Spawn tasks to sum each row in parallel using views
+julia> Dagger.spawn_datadeps() do
+           sz = size(A.chunks,1)
+           nrows_per_chunk = nrows ÷ sz
+           for i in 1:sz
+               for j in 1:nrows_per_chunk
+                   Dagger.@spawn sum_array_row!(Out(view(row_sums, (nrows_per_chunk*(i-1)+j):(nrows_per_chunk*(i-1)+j))),
+                                                In(Dagger.view(A.chunks[i,1], j:j, :)))
+               end
+           end
+       end
+
+# Print the result
+julia> println("Row sums: ", row_sums)
+Row sums: [499.8765, 500.1234, ..., 499.9876]
+```
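+
+Because writes through an `Out` or `InOut` view are applied to the underlying chunk, a `ChunkView` can also be used to mutate part of a `DArray` in place. Here is a minimal sketch which zeroes the top-left chunk of the 64×64 `DA` from the examples above (`zero_block!` is just an illustrative helper):
+
+```julia
+julia> zero_block!(v) = (v .= 0; nothing) # helper used only in this sketch
+
+julia> Dagger.spawn_datadeps() do
+           Dagger.@spawn zero_block!(InOut(view(DA.chunks[1,1], :, :)))
+       end
+
+julia> collect(DA)[1:8, 1:8] == zeros(8, 8)
+true
+```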
diff --git a/src/datadeps.jl b/src/datadeps.jl
index 4642a8790..a4f64d2bd 100644
--- a/src/datadeps.jl
+++ b/src/datadeps.jl
@@ -450,6 +450,10 @@ end
 
 # Make a copy of each piece of data on each worker
 # memory_space => {arg => copy_of_arg}
+# Whether `x` is a handle to data that may live on another worker
+isremotehandle(x) = false
+isremotehandle(x::DTask) = true
+isremotehandle(x::Chunk) = true
 function generate_slot!(state::DataDepsState, dest_space, data)
     if data isa DTask
         data = fetch(data; raw=true)
@@ -458,22 +462,25 @@ function generate_slot!(state::DataDepsState, dest_space, data)
     to_proc = first(processors(dest_space))
     from_proc = first(processors(orig_space))
     dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space)
-    if orig_space == dest_space
+    if orig_space == dest_space && (data isa Chunk || !isremotehandle(data))
+        # Fast path for local data or data already in a Chunk
         data_chunk = tochunk(data, from_proc)
         dest_space_args[data] = data_chunk
        @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc
         @assert memory_space(data_chunk) == orig_space
     else
-        w = only(unique(map(get_parent, collect(processors(dest_space))))).pid
+        to_w = root_worker_id(dest_space)
         ctx = Sch.eager_context()
         id = rand(Int)
         timespan_start(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data))
-        dest_space_args[data] = remotecall_fetch(w, from_proc, to_proc, data) do from_proc, to_proc, data
+        dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data
             data_converted = move(from_proc, to_proc, data)
             data_chunk = tochunk(data_converted, to_proc)
             @assert processor(data_chunk) in processors(dest_space)
             @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
-            @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
+            if orig_space != dest_space
+                @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
+            end
             return data_chunk
         end
         timespan_finish(ctx, :move, (;thunk_id=0, id, position=0, processor=to_proc), (;f=nothing, data=dest_space_args[data]))
diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl
index 0ef0d1200..090c0a14f 100644
--- a/src/memory-spaces.jl
+++ b/src/memory-spaces.jl
@@ -122,6 +122,7 @@ memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `
 memory_spans(x) = memory_spans(aliasing(x))
 memory_spans(x, T) = memory_spans(aliasing(x, T))
 
+
 struct AliasingWrapper <: AbstractAliasing
     inner::AbstractAliasing
     hash::UInt64
@@ -387,3 +388,73 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan)
     y_end = y_span.ptr + y_span.len - 1
     return x_span.ptr <= y_end && y_span.ptr <= x_end
 end
+
+# A lazy, sliced view of the array contained in a `Chunk`
+struct ChunkView{N}
+    chunk::Chunk
+    slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}}
+end
+
+function Base.view(c::Chunk, slices...)
+    if c.domain isa ArrayDomain
+        nd, sz = ndims(c.domain), size(c.domain)
+        nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))"))
+
+        for (i, s) in enumerate(slices)
+            if s isa Int
+                1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))"))
+            elseif s isa AbstractRange
+                isempty(s) && continue
+                1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))"))
+            elseif s === Colon()
+                continue
+            else
+                throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i; expected Int, AbstractRange, or Colon"))
+            end
+        end
+    end
+
+    return ChunkView(c, slices)
+end
+
+Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...)
+
+# Compute aliasing info for the viewed region on the chunk's owning worker
+function aliasing(x::ChunkView{N}) where N
+    remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices
+        x = unwrap(x)
+        v = view(x, slices...)
+        return aliasing(v)
+    end
+end
+memory_space(x::ChunkView) = memory_space(x.chunk)
+isremotehandle(x::ChunkView) = true
+
+#=
+function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView)
+    to_w = root_worker_id(to_space)
+    @assert to_w == myid()
+    to_raw = unwrap(to.chunk)
+    from_w = root_worker_id(from_space)
+    from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk)
+    from_view = view(from_raw, from.slices...)
+    to_view = view(to_raw, to.slices...)
+    move!(dep_mod, to_space, from_space, to_view, from_view)
+    return
+end
+=#
+
+function move(from_proc::Processor, to_proc::Processor, slice::ChunkView)
+    if from_proc == to_proc
+        return view(unwrap(slice.chunk), slice.slices...)
+    else
+        # Need to copy the underlying data, so collapse the view
+        from_w = root_worker_id(from_proc)
+        data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices
+            copy(view(unwrap(chunk), slices...))
+        end
+        return move(from_proc, to_proc, data)
+    end
+end
+
+Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...)
\ No newline at end of file
diff --git a/test/datadeps.jl b/test/datadeps.jl
index cfad2f041..66e41de18 100644
--- a/test/datadeps.jl
+++ b/test/datadeps.jl
@@ -1,3 +1,4 @@
+import Dagger: ChunkView, Chunk
 using LinearAlgebra, Graphs
 
 @testset "Memory Aliasing" begin
@@ -17,6 +18,69 @@ using LinearAlgebra, Graphs
     @test s.sz == sizeof(3)
 end
 
+@testset "ChunkView" begin
+    DA = rand(Blocks(8, 8), 64, 64)
+    task1 = DA.chunks[1,1]::DTask
+    chunk1 = fetch(task1; raw=true)::Chunk
+    v1 = view(chunk1, :, :)
+    task2 = DA.chunks[1,2]::DTask
+    chunk2 = fetch(task2; raw=true)::Chunk
+    v2 = view(chunk2, :, :)
+
+    for obj in (chunk1, task1)
+        @testset "Valid Slices" begin
+            @test view(obj, :, :) isa ChunkView && view(obj, 1:8, 1:8) isa ChunkView
+            @test view(obj, 1:2:7, :) isa ChunkView && view(obj, :, 2:2:8) isa ChunkView
+            @test view(obj, 1, :) isa ChunkView && view(obj, :, 1) isa ChunkView
+            @test view(obj, 3:3, 5:5) isa ChunkView && view(obj, 5:7, 1:2:4) isa ChunkView
+            @test view(obj, 8, 8) isa ChunkView
+            @test view(obj, 1:0, :) isa ChunkView
+        end
+
+        @testset "Dimension Mismatch" begin
+            @test_throws DimensionMismatch view(obj, :)
+            @test_throws DimensionMismatch view(obj, :, :, :)
+        end
+
+        @testset "Int Slice Out of Bounds" begin
+            @test_throws ArgumentError view(obj, 0, :)
+            @test_throws ArgumentError view(obj, :, 9)
+            @test_throws ArgumentError view(obj, 9, 1)
+        end
+
+        @testset "Range Slice Out of Bounds" begin
+            @test_throws ArgumentError view(obj, 0:5, :)
+            @test_throws ArgumentError view(obj, 1:8, 5:10)
+            @test_throws ArgumentError view(obj, 2:2:10, :)
+        end
+
+        @testset "Invalid Slice Types" begin
+            @test_throws DimensionMismatch view(obj, (1:2, :))
+            @test_throws ArgumentError view(obj, :, [1, 2])
+        end
+    end
+
+    @test fetch(v1) == fetch(chunk1)
+
+    @test Dagger.memory_space(v1) == Dagger.memory_space(chunk1)
+    @test Dagger.aliasing(v1) isa Dagger.StridedAliasing
+    ptr = remotecall_fetch(chunk1.handle.owner, chunk1) do chunk
+        UInt(pointer(Dagger.unwrap(chunk)))
+    end
+    @test Dagger.aliasing(v1).base_ptr.addr == ptr
+
+    @testset "Aliasing" begin
+        f! = v1 -> begin
+            v1 .= 0
+            return
+        end
+        Dagger.spawn_datadeps() do
+            Dagger.@spawn f!(InOut(v1))
+        end
+        @test collect(DA)[1:8, 1:8] == zeros(8, 8)
+    end
+end
+
 function with_logs(f)
     Dagger.enable_logging!(;taskdeps=true, taskargs=true)
     try