1- using . Metal
2- import . Metal: MtlArray, MtlDevice
1+ module MetalExt
32
4- struct MtlArrayDeviceProc <: Dagger.Processor
5- owner:: Int
6- device_id:: UInt64
7- end
3+ export MtlArrayDeviceProc
84
9- # Assume that we can run anything.
10- Dagger. iscompatible_func (proc:: MtlArrayDeviceProc , opts, f) = true
11- Dagger. iscompatible_arg (proc:: MtlArrayDeviceProc , opts, x) = true
5+ import Dagger, DaggerGPU
6+ import Distributed: myid
127
13- # CPUs shouldn't process our array type.
14- Dagger. iscompatible_arg (proc:: Dagger.ThreadProc , opts, x:: MtlArray ) = false
8+ const CPUProc = Union{Dagger. OSProc,Dagger. ThreadProc}
159
16- function Dagger. move (from_proc:: OSProc , to_proc:: MtlArrayDeviceProc , x:: Chunk )
17- from_pid = from_proc. pid
18- to_pid = Dagger. get_parent (to_proc). pid
19- @assert myid () == to_pid
20-
21- return Dagger. move (from_proc, to_proc, remotecall_fetch (x-> poolget (x. handle), from_pid, x))
10+ if isdefined (Base, :get_extension )
11+ import Metal
12+ else
13+ import .. Metal
2214end
15+ import Metal: MtlArray, MetalBackend
16+ const MtlDevice = Metal. MTL. MTLDeviceInstance
2317
24- function Dagger. move (from_proc:: MtlArrayDeviceProc , to_proc:: OSProc , x:: Chunk )
25- from_pid = Dagger. get_parent (from_proc). pid
26- to_pid = to_proc. pid
27- @assert myid () == to_pid
28-
29- return remotecall_fetch (from_pid, x) do x
30- mtlarray = poolget (x. handle)
31- return Dagger. move (from_proc, to_proc, mtlarray)
32- end
18+ struct MtlArrayDeviceProc <: Dagger.Processor
19+ owner:: Int
20+ device_id:: UInt64
3321end
3422
35- function Dagger. move (
36- from_proc:: OSProc ,
23+ DaggerGPU. @gpuproc (MtlArrayDeviceProc, MtlArray)
24+ Dagger. get_parent (proc:: MtlArrayDeviceProc ) = Dagger. OSProc (proc. owner)
25+
26+ function DaggerGPU. move_optimized (
27+ from_proc:: CPUProc ,
3728 to_proc:: MtlArrayDeviceProc ,
38- x:: Array{T, N}
39- ) where {T, N}
29+ x:: Array
30+ )
31+ # FIXME
32+ return nothing
33+
4034 # If we have unified memory, we can try casting the `Array` to `MtlArray`.
4135 device = _get_metal_device (to_proc)
4236
@@ -45,68 +39,75 @@ function Dagger.move(
4539 marray != = nothing && return marray
4640 end
4741
48- return adapt (MtlArray, x)
42+ return nothing
4943end
5044
51- function Dagger. move (from_proc:: OSProc , to_proc:: MtlArrayDeviceProc , x)
52- adapt (MtlArray, x)
53- end
5445
55- function Dagger . move (
46+ function DaggerGPU . move_optimized (
5647 from_proc:: MtlArrayDeviceProc ,
57- to_proc:: OSProc ,
58- x:: Array{T, N}
59- ) where {T, N}
48+ to_proc:: CPUProc ,
49+ x:: Array
50+ )
51+ # FIXME
52+ return nothing
53+
6054 # If we have unified memory, we can just cast the `MtlArray` to an `Array`.
6155 device = _get_metal_device (from_proc)
6256
6357 if (device != = nothing ) && device. hasUnifiedMemory
6458 return unsafe_wrap (Array{T}, x, size (x))
65- else
66- return adapt (Array, x)
6759 end
68- end
6960
70- function Dagger. move (from_proc:: MtlArrayDeviceProc , to_proc:: OSProc , x)
71- adapt (Array, x)
61+ return nothing
7262end
7363
74- Dagger. get_parent (proc:: MtlArrayDeviceProc ) = Dagger. OSProc (proc. owner)
75-
76- function Dagger. execute! (proc:: MtlArrayDeviceProc , func, args... )
64+ function Dagger. execute! (proc:: MtlArrayDeviceProc , f, args... ; kwargs... )
65+ @nospecialize f args kwargs
7766 tls = Dagger. get_tls ()
7867 task = Threads. @spawn begin
7968 Dagger. set_tls! (tls)
80- Metal. @sync func (args... )
69+ result = Base. @invokelatest f (args... ; kwargs... )
70+ Metal. synchronize ()
71+ return result
8172 end
8273
8374 try
8475 fetch (task)
8576 catch err
86- @static if VERSION >= v " 1.1"
87- stk = Base. catch_stack (task)
88- err, frames = stk[1 ]
89- rethrow (CapturedException (err, frames))
90- else
91- rethrow (task. result)
92- end
77+ stk = current_exceptions (task)
78+ err, frames = stk[1 ]
79+ rethrow (CapturedException (err, frames))
9380 end
9481end
9582
9683function Base. show (io:: IO , proc:: MtlArrayDeviceProc )
97- print (io, " MtlArrayDeviceProc on worker $(proc. owner) , device ( $(something (_get_metal_device (proc)). name) )" )
84+ print (io, " MtlArrayDeviceProc( worker $(proc. owner) , device $(something (_get_metal_device (proc)). name) )" )
9885end
9986
100- processor (:: Val{:Metal} ) = MtlArrayDeviceProc
101- cancompute (:: Val{:Metal} ) = length (Metal. devices ()) >= 1
102- kernel_backend (proc:: MtlArrayDeviceProc ) = _get_metal_device (proc)
87+ DaggerGPU. processor (:: Val{:Metal} ) = MtlArrayDeviceProc
88+ DaggerGPU. cancompute (:: Val{:Metal} ) = Metal. functional ()
89+ DaggerGPU. kernel_backend (proc:: MtlArrayDeviceProc ) = MetalBackend ()
90+ # TODO : Switch devices
91+ DaggerGPU. with_device (f, proc:: MtlArrayDeviceProc ) = f ()
92+
93+ function Dagger. to_scope (:: Val{:metal_gpu} , sc:: NamedTuple )
94+ worker = get (sc, :worker , 1 )
95+ dev_id = sc. metal_gpu
96+ dev = Metal. devices ()[dev_id]
97+ return Dagger. ExactScope (MtlArrayDeviceProc (worker, dev. registryID))
98+ end
99+ Dagger. scope_key_precedence (:: Val{:metal_gpu} ) = 1
103100
104- for dev in Metal. devices ()
105- Dagger. add_processor_callback! (" metal_device_$(dev. registryID) " ) do
106- MtlArrayDeviceProc (Distributed. myid (), dev. registryID)
101+ function __init__ ()
102+ for dev in Metal. devices ()
103+ @debug " Registering Metal GPU processor with Dagger: $dev "
104+ Dagger. add_processor_callback! (" metal_device_$(dev. registryID) " ) do
105+ MtlArrayDeviceProc (myid (), dev. registryID)
106+ end
107107 end
108108end
109109
110+
110111# ###############################################################################
111112# Private functions
112113# ###############################################################################
@@ -149,3 +150,5 @@ function _get_metal_device(proc::MtlArrayDeviceProc)
149150 return devices[id]
150151 end
151152end
153+
154+ end # module MetalExt
0 commit comments