
Commit 1e0cc54

Merge pull request #18 from JuliaGPU/jps/rocm-revamp
Use package extensions, ROCm/AMDGPU revamp
2 parents 0dc5aa4 + 5353722

File tree

7 files changed: +324 -183 lines changed

Project.toml

Lines changed: 20 additions & 4 deletions
@@ -1,21 +1,37 @@
 name = "DaggerGPU"
 uuid = "68e73e28-2238-4d5a-bf97-e5d4aa3c4be2"
 authors = ["Julian P Samaroo <[email protected]>"]
-version = "0.1.5"
+version = "0.1.6"

 [deps]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

+[weakdeps]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+
+[extensions]
+CUDAExt = "CUDA"
+MetalExt = "Metal"
+ROCExt = "AMDGPU"
+
 [compat]
+AMDGPU = "0.4"
 Adapt = "1, 2, 3"
-Dagger = "0.13.3, 0.14, 0.15, 0.16"
-KernelAbstractions = "0.5, 0.6, 0.7, 0.8"
+CUDA = "3, 4"
+Dagger = "0.17"
+KernelAbstractions = "0.9"
 MemPool = "0.3, 0.4"
+Metal = "0.3, 0.4"
 Requires = "1"
-julia = "1.6"
+julia = "1.7"
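With the new [weakdeps]/[extensions] sections, the GPU backends load lazily: on Julia 1.9+, loading a weak dependency such as CUDA alongside DaggerGPU activates the matching extension module, while on the older Julia versions still permitted by julia = "1.7" the Requires-based fallback (the isdefined(Base, :get_extension) branches in the extension files below) loads the same code. The KernelAbstractions bump to 0.9 also matters here, since that release replaced device types like CUDADevice with backend types like CUDABackend. A minimal usage sketch, assuming CUDA.jl is installed on a CUDA-capable host:

    # Opting into the CUDA backend via the extension mechanism.
    using Dagger, DaggerGPU
    using CUDA   # on Julia 1.9+, this triggers loading of the CUDAExt extension

    # cancompute(::Val{:CUDA}) is defined by CUDAExt (see the diff below);
    # it reports whether a usable CUDA GPU was detected.
    DaggerGPU.cancompute(Val(:CUDA))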
Lines changed: 52 additions & 26 deletions
@@ -1,17 +1,28 @@
-using .CUDA
-import .CUDA: CuDevice, CuContext, devices, attribute
-
-using UUIDs
+module CUDAExt

 export CuArrayDeviceProc

+import Dagger, DaggerGPU, MemPool
+import Distributed: myid, remotecall_fetch
+
+const CPUProc = Union{Dagger.OSProc,Dagger.ThreadProc}
+
+if isdefined(Base, :get_extension)
+    import CUDA
+else
+    import ..CUDA
+end
+import CUDA: CuDevice, CuContext, CuArray, CUDABackend, devices, attribute
+
+using UUIDs
+
 "Represents a single CUDA GPU device."
 struct CuArrayDeviceProc <: Dagger.Processor
     owner::Int
     device::Int
     device_uuid::UUID
 end
-@gpuproc(CuArrayDeviceProc, CuArray)
+DaggerGPU.@gpuproc(CuArrayDeviceProc, CuArray)
 Dagger.get_parent(proc::CuArrayDeviceProc) = Dagger.OSProc(proc.owner)

 # function can_access(this, peer)
@@ -23,10 +34,10 @@ Dagger.get_parent(proc::CuArrayDeviceProc) = Dagger.OSProc(proc.owner)
 function Dagger.move(from::CuArrayDeviceProc, to::CuArrayDeviceProc, x::Dagger.Chunk{T}) where T<:CuArray
     if from == to
         # Same process and GPU, no change
-        poolget(x.handle)
+        MemPool.poolget(x.handle)
     elseif from.owner == to.owner
         # Same process but different GPUs, use DtoD copy
-        from_arr = poolget(x.handle)
+        from_arr = MemPool.poolget(x.handle)
         to_arr = CUDA.device!(to.device) do
             CuArray{T,N}(undef, size)
         end
@@ -35,7 +46,7 @@ function Dagger.move(from::CuArrayDeviceProc, to::CuArrayDeviceProc, x::Dagger.C
     elseif Dagger.system_uuid(from.owner) == Dagger.system_uuid(to.owner)
         # Same node, we can use IPC
         ipc_handle, eT, shape = remotecall_fetch(from.owner, x.handle) do h
-            arr = poolget(h)
+            arr = MemPool.poolget(h)
             ipc_handle_ref = Ref{CUDA.CUipcMemHandle}()
             GC.@preserve arr begin
                 CUDA.cuIpcGetMemHandle(ipc_handle_ref, pointer(arr))
@@ -64,41 +75,56 @@ function Dagger.move(from::CuArrayDeviceProc, to::CuArrayDeviceProc, x::Dagger.C
         # Different node, use DtoH, serialization, HtoD
         # TODO UCX
         CuArray(remotecall_fetch(from.owner, x.handle) do h
-            Array(poolget(h))
+            Array(MemPool.poolget(h))
         end)
     end
 end

-function Dagger.execute!(proc::CuArrayDeviceProc, func, args...)
+function Dagger.execute!(proc::CuArrayDeviceProc, f, args...; kwargs...)
+    @nospecialize f args kwargs
     tls = Dagger.get_tls()
     task = Threads.@spawn begin
         Dagger.set_tls!(tls)
         CUDA.device!(proc.device)
-        CUDA.@sync func(args...)
+        result = Base.@invokelatest f(args...; kwargs...)
+        CUDA.synchronize()
+        return result
     end
+
     try
         fetch(task)
     catch err
-        @static if VERSION >= v"1.1"
-            stk = Base.catch_stack(task)
-            err, frames = stk[1]
-            rethrow(CapturedException(err, frames))
-        else
-            rethrow(task.result)
-        end
+        stk = current_exceptions(task)
+        err, frames = stk[1]
+        rethrow(CapturedException(err, frames))
     end
 end
 Base.show(io::IO, proc::CuArrayDeviceProc) =
-    print(io, "CuArrayDeviceProc on worker $(proc.owner), device $(proc.device), uuid $(proc.device_uuid)")
+    print(io, "CuArrayDeviceProc(worker $(proc.owner), device $(proc.device), uuid $(proc.device_uuid))")

-processor(::Val{:CUDA}) = CuArrayDeviceProc
-cancompute(::Val{:CUDA}) = CUDA.has_cuda()
-kernel_backend(::CuArrayDeviceProc) = CUDADevice()
+DaggerGPU.processor(::Val{:CUDA}) = CuArrayDeviceProc
+DaggerGPU.cancompute(::Val{:CUDA}) = CUDA.has_cuda()
+DaggerGPU.kernel_backend(::CuArrayDeviceProc) = CUDABackend()
+DaggerGPU.with_device(f, proc::CuArrayDeviceProc) =
+    CUDA.device!(f, proc.device)

-if CUDA.has_cuda()
-    for dev in devices()
-        Dagger.add_processor_callback!("cuarray_device_$(dev.handle)") do
-            CuArrayDeviceProc(Distributed.myid(), dev.handle, CUDA.uuid(dev))
+function Dagger.to_scope(::Val{:cuda_gpu}, sc::NamedTuple)
+    worker = get(sc, :worker, 1)
+    dev_id = sc.cuda_gpu
+    dev = collect(CUDA.devices())[dev_id]
+    return Dagger.ExactScope(CuArrayDeviceProc(worker, dev_id-1, CUDA.uuid(dev)))
+end
+Dagger.scope_key_precedence(::Val{:cuda_gpu}) = 1
+
+function __init__()
+    if CUDA.has_cuda()
+        for dev in CUDA.devices()
+            @debug "Registering CUDA GPU processor with Dagger: $dev"
+            Dagger.add_processor_callback!("cuarray_device_$(dev.handle)") do
+                CuArrayDeviceProc(myid(), dev.handle, CUDA.uuid(dev))
            end
         end
     end
 end
+
+end # module CUDAExt
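The new Dagger.to_scope method makes individual GPUs addressable from task scopes. A hedged sketch of caller-side usage, assuming Dagger 0.17's Dagger.scope constructor routes NamedTuple keys through to_scope and that a CUDA GPU is present; the workload here is illustrative:

    using Dagger, DaggerGPU, CUDA

    # Pin a task to the first CUDA device on worker 1 via the new :cuda_gpu key.
    sc = Dagger.scope(worker=1, cuda_gpu=1)
    t = Dagger.@spawn scope=sc sum(CUDA.rand(Float32, 1024))
    fetch(t)   # executes on CuArrayDeviceProc(worker 1, device 0, ...)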
Lines changed: 64 additions & 61 deletions
@@ -1,42 +1,36 @@
-using .Metal
-import .Metal: MtlArray, MtlDevice
+module MetalExt

-struct MtlArrayDeviceProc <: Dagger.Processor
-    owner::Int
-    device_id::UInt64
-end
+export MtlArrayDeviceProc

-# Assume that we can run anything.
-Dagger.iscompatible_func(proc::MtlArrayDeviceProc, opts, f) = true
-Dagger.iscompatible_arg(proc::MtlArrayDeviceProc, opts, x) = true
+import Dagger, DaggerGPU
+import Distributed: myid

-# CPUs shouldn't process our array type.
-Dagger.iscompatible_arg(proc::Dagger.ThreadProc, opts, x::MtlArray) = false
+const CPUProc = Union{Dagger.OSProc,Dagger.ThreadProc}

-function Dagger.move(from_proc::OSProc, to_proc::MtlArrayDeviceProc, x::Chunk)
-    from_pid = from_proc.pid
-    to_pid = Dagger.get_parent(to_proc).pid
-    @assert myid() == to_pid
-
-    return Dagger.move(from_proc, to_proc, remotecall_fetch(x->poolget(x.handle), from_pid, x))
+if isdefined(Base, :get_extension)
+    import Metal
+else
+    import ..Metal
 end
+import Metal: MtlArray, MetalBackend
+const MtlDevice = Metal.MTL.MTLDeviceInstance

-function Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::OSProc, x::Chunk)
-    from_pid = Dagger.get_parent(from_proc).pid
-    to_pid = to_proc.pid
-    @assert myid() == to_pid
-
-    return remotecall_fetch(from_pid, x) do x
-        mtlarray = poolget(x.handle)
-        return Dagger.move(from_proc, to_proc, mtlarray)
-    end
+struct MtlArrayDeviceProc <: Dagger.Processor
+    owner::Int
+    device_id::UInt64
 end

-function Dagger.move(
-    from_proc::OSProc,
+DaggerGPU.@gpuproc(MtlArrayDeviceProc, MtlArray)
+Dagger.get_parent(proc::MtlArrayDeviceProc) = Dagger.OSProc(proc.owner)
+
+function DaggerGPU.move_optimized(
+    from_proc::CPUProc,
     to_proc::MtlArrayDeviceProc,
-    x::Array{T, N}
-) where {T, N}
+    x::Array
+)
+    # FIXME
+    return nothing
+
     # If we have unified memory, we can try casting the `Array` to `MtlArray`.
     device = _get_metal_device(to_proc)

@@ -45,68 +39,75 @@ function Dagger.move(
         marray !== nothing && return marray
     end

-    return adapt(MtlArray, x)
+    return nothing
 end

-function Dagger.move(from_proc::OSProc, to_proc::MtlArrayDeviceProc, x)
-    adapt(MtlArray, x)
-end

-function Dagger.move(
+function DaggerGPU.move_optimized(
     from_proc::MtlArrayDeviceProc,
-    to_proc::OSProc,
-    x::Array{T, N}
-) where {T, N}
+    to_proc::CPUProc,
+    x::Array
+)
+    # FIXME
+    return nothing
+
     # If we have unified memory, we can just cast the `MtlArray` to an `Array`.
     device = _get_metal_device(from_proc)

     if (device !== nothing) && device.hasUnifiedMemory
         return unsafe_wrap(Array{T}, x, size(x))
-    else
-        return adapt(Array, x)
     end
-end

-function Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::OSProc, x)
-    adapt(Array, x)
+    return nothing
 end

-Dagger.get_parent(proc::MtlArrayDeviceProc) = Dagger.OSProc(proc.owner)
-
-function Dagger.execute!(proc::MtlArrayDeviceProc, func, args...)
+function Dagger.execute!(proc::MtlArrayDeviceProc, f, args...; kwargs...)
+    @nospecialize f args kwargs
     tls = Dagger.get_tls()
     task = Threads.@spawn begin
         Dagger.set_tls!(tls)
-        Metal.@sync func(args...)
+        result = Base.@invokelatest f(args...; kwargs...)
+        Metal.synchronize()
+        return result
     end

     try
         fetch(task)
     catch err
-        @static if VERSION >= v"1.1"
-            stk = Base.catch_stack(task)
-            err, frames = stk[1]
-            rethrow(CapturedException(err, frames))
-        else
-            rethrow(task.result)
-        end
+        stk = current_exceptions(task)
+        err, frames = stk[1]
+        rethrow(CapturedException(err, frames))
     end
 end

 function Base.show(io::IO, proc::MtlArrayDeviceProc)
-    print(io, "MtlArrayDeviceProc on worker $(proc.owner), device ($(something(_get_metal_device(proc)).name))")
+    print(io, "MtlArrayDeviceProc(worker $(proc.owner), device $(something(_get_metal_device(proc)).name))")
 end

-processor(::Val{:Metal}) = MtlArrayDeviceProc
-cancompute(::Val{:Metal}) = length(Metal.devices()) >= 1
-kernel_backend(proc::MtlArrayDeviceProc) = _get_metal_device(proc)
+DaggerGPU.processor(::Val{:Metal}) = MtlArrayDeviceProc
+DaggerGPU.cancompute(::Val{:Metal}) = Metal.functional()
+DaggerGPU.kernel_backend(proc::MtlArrayDeviceProc) = MetalBackend()
+# TODO: Switch devices
+DaggerGPU.with_device(f, proc::MtlArrayDeviceProc) = f()
+
+function Dagger.to_scope(::Val{:metal_gpu}, sc::NamedTuple)
+    worker = get(sc, :worker, 1)
+    dev_id = sc.metal_gpu
+    dev = Metal.devices()[dev_id]
+    return Dagger.ExactScope(MtlArrayDeviceProc(worker, dev.registryID))
+end
+Dagger.scope_key_precedence(::Val{:metal_gpu}) = 1

-for dev in Metal.devices()
-    Dagger.add_processor_callback!("metal_device_$(dev.registryID)") do
-        MtlArrayDeviceProc(Distributed.myid(), dev.registryID)
+function __init__()
+    for dev in Metal.devices()
+        @debug "Registering Metal GPU processor with Dagger: $dev"
+        Dagger.add_processor_callback!("metal_device_$(dev.registryID)") do
+            MtlArrayDeviceProc(myid(), dev.registryID)
+        end
     end
 end

+
 ################################################################################
 # Private functions
 ################################################################################
@@ -149,3 +150,5 @@ function _get_metal_device(proc::MtlArrayDeviceProc)
     return devices[id]
 end
 end
+
+end # module MetalExt
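The Metal extension mirrors the CUDA one, exposing a :metal_gpu scope key. An analogous hedged sketch, under the same assumption about Dagger.scope and assuming an Apple-silicon host where Metal.functional() is true:

    using Dagger, DaggerGPU, Metal

    # Pin a task to the first Metal device via the new :metal_gpu scope key.
    sc = Dagger.scope(worker=1, metal_gpu=1)
    t = Dagger.@spawn scope=sc sum(MtlArray(rand(Float32, 1024)))
    fetch(t)   # executes on the corresponding MtlArrayDeviceProc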
