Skip to content

Commit 6b7b35c

Browse files
committed
TEMP datadeps: Support CUDA execution
1 parent dd8566d commit 6b7b35c

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

ext/CUDAExt.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ function Dagger.memory_space(x::CuArray)
4848
device_uuid = CUDA.uuid(dev)
4949
return CUDAVRAMMemorySpace(myid(), device_id, device_uuid)
5050
end
51+
function Dagger.aliasing(x::CuArray{T}) where T
52+
space = Dagger.memory_space(x)
53+
S = typeof(space)
54+
cuptr = pointer(x)
55+
rptr = Dagger.RemotePtr{Cvoid}(UInt64(cuptr), space)
56+
return Dagger.ContiguousAliasing(Dagger.MemorySpan{S}(rptr, sizeof(T)*length(x)))
57+
end
58+
Dagger.pointer_in_space(ptr::UInt64, space::CUDAVRAMMemorySpace) =
59+
CUDA.CuPtr{UInt8}(ptr)
60+
# TODO: Make async=true
61+
Dagger.unsafe_copyto_spaces!(to_space::CUDAVRAMMemorySpace, from_space::CUDAVRAMMemorySpace, to_ptr, from_ptr, len::UInt64) =
62+
unsafe_copyto!(to_ptr, from_ptr, len; async=false)
63+
Dagger.unsafe_copyto_spaces!(to_space::CUDAVRAMMemorySpace, from_space::CPURAMMemorySpace, to_ptr, from_ptr, len::UInt64) =
64+
unsafe_copyto!(to_ptr, from_ptr, len; async=false)
65+
Dagger.unsafe_copyto_spaces!(to_space::CPURAMMemorySpace, from_space::CUDAVRAMMemorySpace, to_ptr, from_ptr, len::UInt64) =
66+
unsafe_copyto!(to_ptr, from_ptr, len; async=false)
5167

5268
Dagger.memory_spaces(proc::CuArrayDeviceProc) = Set([CUDAVRAMMemorySpace(proc.owner, proc.device, proc.device_uuid)])
5369
Dagger.processors(space::CUDAVRAMMemorySpace) = Set([CuArrayDeviceProc(space.owner, space.device, space.device_uuid)])

src/datadeps/remainders.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ function compute_remainder_for_arg!(state::DataDepsState,
143143
push!(target_ainfos, LocalMemorySpan.(spans))
144144
end
145145
nspans = length(first(target_ainfos))
146+
@assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces"
146147

147148
# FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...)
148149
for entry in state.arg_history[arg_w]
@@ -433,16 +434,18 @@ end
433434

434435
# Main copy function for RemainderAliasing
435436
function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S
437+
# TODO: Support direct copy between GPU memory spaces
438+
436439
# Copy the data from the source object
437440
copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod
438441
len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans)
439442
copies = Vector{UInt8}(undef, len)
440443
offset = 1
441444
GC.@preserve copies begin
442445
for (from_span, _) in dep_mod.spans
443-
from_ptr = Ptr{UInt8}(from_span.ptr)
446+
from_ptr = pointer_in_space(from_span.ptr, from_space)
444447
to_ptr = Ptr{UInt8}(pointer(copies, offset))
445-
unsafe_copyto!(to_ptr, from_ptr, from_span.len)
448+
unsafe_copyto_spaces!(CPURAMMemorySpace(), from_space, to_ptr, from_ptr, from_span.len)
446449
offset += from_span.len
447450
end
448451
end
@@ -455,8 +458,8 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space:
455458
GC.@preserve copies begin
456459
for (_, to_span) in dep_mod.spans
457460
from_ptr = Ptr{UInt8}(pointer(copies, offset))
458-
to_ptr = Ptr{UInt8}(to_span.ptr)
459-
unsafe_copyto!(to_ptr, from_ptr, to_span.len)
461+
to_ptr = pointer_in_space(to_span.ptr, to_space)
462+
unsafe_copyto_spaces!(to_space, CPURAMMemorySpace(), to_ptr, from_ptr, to_span.len)
460463
offset += to_span.len
461464
end
462465
@assert offset == length(copies)+1
@@ -467,3 +470,7 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space:
467470

468471
return
469472
end
473+
474+
pointer_in_space(ptr::UInt64, space::CPURAMMemorySpace) = Ptr{UInt8}(ptr)
475+
unsafe_copyto_spaces!(to_space, from_space, to_ptr, from_ptr, len::UInt64) =
476+
unsafe_copyto!(to_ptr, from_ptr, len)

src/memory-spaces.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
struct CPURAMMemorySpace <: MemorySpace
22
owner::Int
33
end
4+
CPURAMMemorySpace() = CPURAMMemorySpace(myid())
45
root_worker_id(space::CPURAMMemorySpace) = space.owner
56

67
memory_space(x) = CPURAMMemorySpace(myid())
@@ -87,7 +88,8 @@ function type_may_alias(::Type{T}) where T
8788
return false
8889
end
8990

90-
may_alias(::MemorySpace, ::MemorySpace) = true
91+
may_alias(::MemorySpace, ::MemorySpace) = false
92+
may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2
9193
may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner
9294

9395
abstract type AbstractAliasing end
@@ -571,7 +573,7 @@ end
571573
function will_alias(x_span::MemorySpan, y_span::MemorySpan)
572574
may_alias(x_span.ptr.space, y_span.ptr.space) || return false
573575
# FIXME: Allow pointer conversion instead of just failing
574-
@assert x_span.ptr.space == y_span.ptr.space
576+
@assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. $(y_span.ptr.space)"
575577
x_end = x_span.ptr + x_span.len - 1
576578
y_end = y_span.ptr + y_span.len - 1
577579
return x_span.ptr <= y_end && y_span.ptr <= x_end

0 commit comments

Comments
 (0)