
Commit a76fb42

GPU: Add initial OpenCL support
1 parent 494e584

5 files changed: +450 −5 lines

.buildkite/pipeline.yml

Lines changed: 16 additions & 0 deletions
@@ -151,6 +151,22 @@ steps:
       env:
         CI_USE_METAL: "1"
 
+  - label: Julia 1.11 (OpenCL)
+    timeout_in_minutes: 20
+    <<: *gputest
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1.11"
+      - JuliaCI/julia-test#v1:
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+    agents:
+      queue: "juliaecosystem"
+      os: linux
+      arch: x86_64
+    env:
+      CI_USE_OPENCL: "1"
+
   - label: Julia 1 - TimespanLogging
     timeout_in_minutes: 20
     <<: *test
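[Editorial note, not part of the commit: the new CI_USE_OPENCL flag presumably gates the OpenCL test group the same way CI_USE_METAL gates the Metal job above. A minimal sketch of such gating in a test harness; the conditional wiring here is illustrative, only the flag name and gpu_can_compute hook come from this diff:

    # Illustrative sketch: run GPU tests against OpenCL only when the CI flag is set.
    if get(ENV, "CI_USE_OPENCL", "0") == "1"
        using OpenCL
        # gpu_can_compute(Val(:OpenCL)) is defined in ext/OpenCLExt.jl below
        @assert Dagger.gpu_can_compute(Val(:OpenCL))
    end
]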

Project.toml

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
 oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -50,6 +51,7 @@ GraphVizSimpleExt = "Colors"
 IntelExt = "oneAPI"
 JSON3Ext = "JSON3"
 MetalExt = "Metal"
+OpenCLExt = "OpenCL"
 PlotsExt = ["DataFrames", "Plots"]
 PythonExt = "PythonCall"
 ROCExt = "AMDGPU"
@@ -72,6 +74,7 @@ MacroTools = "0.5"
 MemPool = "0.4.12"
 Metal = "1.1"
 OnlineStats = "1"
+OpenCL = "0.10"
 oneAPI = "1, 2"
 Plots = "1"
 PrecompileTools = "1.2"
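[Editorial note, not part of the commit: the OpenCLExt = "OpenCL" mapping registers OpenCLExt as a package extension, so the module loads automatically once both Dagger and OpenCL are in the environment (Julia 1.9+). A minimal sketch of checking that the extension activated; get_extension is standard Base API, the rest is illustrative:

    using Dagger
    using OpenCL   # triggers loading of OpenCLExt

    ext = Base.get_extension(Dagger, :OpenCLExt)
    @assert ext !== nothing   # CLArrayDeviceProc etc. are now defined
]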

ext/OpenCLExt.jl

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@ (new file)
module OpenCLExt

export CLArrayDeviceProc

import Dagger, MemPool
import Dagger: CPURAMMemorySpace, Chunk, unwrap
import MemPool: DRef, poolget
import Distributed: myid, remotecall_fetch
import LinearAlgebra
using KernelAbstractions, Adapt

const CPUProc = Union{Dagger.OSProc,Dagger.ThreadProc}

if isdefined(Base, :get_extension)
    import OpenCL
else
    import ..OpenCL
end
import OpenCL: CLArray, OpenCLBackend, cl
import .cl: Device, Context, CmdQueue

using UUIDs

"Represents a single OpenCL device."
struct CLArrayDeviceProc <: Dagger.Processor
    owner::Int
    device::Int
end
Dagger.get_parent(proc::CLArrayDeviceProc) = Dagger.OSProc(proc.owner)
Dagger.root_worker_id(proc::CLArrayDeviceProc) = proc.owner
Base.show(io::IO, proc::CLArrayDeviceProc) =
    print(io, "CLArrayDeviceProc(worker $(proc.owner), device $(proc.device))")
Dagger.short_name(proc::CLArrayDeviceProc) = "W: $(proc.owner), CL: $(proc.device)"
Dagger.@gpuproc(CLArrayDeviceProc, CLArray)

"Represents the memory space of a single OpenCL device's RAM."
struct CLMemorySpace <: Dagger.MemorySpace
    owner::Int
    device::Int
end
Dagger.root_worker_id(space::CLMemorySpace) = space.owner
function Dagger.memory_space(x::CLArray)
    queue = x.data[].queue
    idx = findfirst(==(queue), QUEUES)
    return CLMemorySpace(myid(), idx)
end

Dagger.memory_spaces(proc::CLArrayDeviceProc) = Set([CLMemorySpace(proc.owner, proc.device)])
Dagger.processors(space::CLMemorySpace) = Set([CLArrayDeviceProc(space.owner, space.device)])

function to_device(proc::CLArrayDeviceProc)
    @assert Dagger.root_worker_id(proc) == myid()
    return DEVICES[proc.device]
end
function to_context(proc::CLArrayDeviceProc)
    @assert Dagger.root_worker_id(proc) == myid()
    return CONTEXTS[proc.device]
end
to_context(handle::Integer) = CONTEXTS[handle]
to_context(dev::Device) = to_context(dev.handle)

function with_context!(handle::Integer)
    cl.context!(CONTEXTS[handle])
    cl.queue!(QUEUES[handle])
end
function with_context!(proc::CLArrayDeviceProc)
    @assert Dagger.root_worker_id(proc) == myid()
    with_context!(proc.device)
end
function with_context!(space::CLMemorySpace)
    @assert Dagger.root_worker_id(space) == myid()
    with_context!(space.device)
end
function with_context(f, x)
    old_ctx = cl.context()
    old_queue = cl.queue()

    with_context!(x)
    try
        f()
    finally
        cl.context!(old_ctx)
        cl.queue!(old_queue)
    end
end

function _sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace})
    with_context(x) do
        cl.finish(cl.queue())
    end
end
function sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace})
    if Dagger.root_worker_id(x) == myid()
        _sync_with_context(x)
    else
        # Do nothing, as we have received our value over a serialization
        # boundary, which should synchronize for us
    end
end

# Allocations
Dagger.allocate_array_func(::CLArrayDeviceProc, ::typeof(rand)) = OpenCL.rand
Dagger.allocate_array_func(::CLArrayDeviceProc, ::typeof(randn)) = OpenCL.randn
Dagger.allocate_array_func(::CLArrayDeviceProc, ::typeof(ones)) = OpenCL.ones
Dagger.allocate_array_func(::CLArrayDeviceProc, ::typeof(zeros)) = OpenCL.zeros
struct AllocateUndef{S} end
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = CLArray{S,N}(undef, dims)
Dagger.allocate_array_func(::CLArrayDeviceProc, ::Dagger.AllocateUndef{S}) where S = AllocateUndef{S}()

# In-place
# N.B. These methods assume that later operations will implicitly or
# explicitly synchronize with their associated stream
function Dagger.move!(to_space::Dagger.CPURAMMemorySpace, from_space::CLMemorySpace, to::AbstractArray{T,N}, from::AbstractArray{T,N}) where {T,N}
    if Dagger.root_worker_id(from_space) == myid()
        _sync_with_context(from_space)
        with_context!(from_space)
    end
    copyto!(to, from)
    # N.B. DtoH will synchronize
    return
end
function Dagger.move!(to_space::CLMemorySpace, from_space::Dagger.CPURAMMemorySpace, to::AbstractArray{T,N}, from::AbstractArray{T,N}) where {T,N}
    with_context!(to_space)
    copyto!(to, from)
    return
end
function Dagger.move!(to_space::CLMemorySpace, from_space::CLMemorySpace, to::AbstractArray{T,N}, from::AbstractArray{T,N}) where {T,N}
    sync_with_context(from_space)
    with_context!(to_space)
    copyto!(to, from)
    return
end

# Out-of-place HtoD
function Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x)
    with_context(to_proc) do
        arr = adapt(CLArray, x)
        cl.finish(cl.queue())
        return arr
    end
end
function Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::Chunk)
    from_w = Dagger.root_worker_id(from_proc)
    to_w = Dagger.root_worker_id(to_proc)
    @assert myid() == to_w
    cpu_data = remotecall_fetch(unwrap, from_w, x)
    with_context(to_proc) do
        arr = adapt(CLArray, cpu_data)
        cl.finish(cl.queue())
        return arr
    end
end
function Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::CLArray)
    queue = x.data[].queue
    if queue == QUEUES[to_proc.device]
        return x
    end
    with_context(to_proc) do
        _x = similar(x)
        copyto!(_x, x)
        cl.finish(cl.queue())
        return _x
    end
end

# Out-of-place DtoH
function Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CPUProc, x)
    with_context(from_proc) do
        cl.finish(cl.queue())
        _x = adapt(Array, x)
        cl.finish(cl.queue())
        return _x
    end
end
function Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CPUProc, x::Chunk)
    from_w = Dagger.root_worker_id(from_proc)
    to_w = Dagger.root_worker_id(to_proc)
    @assert myid() == to_w
    remotecall_fetch(from_w, x) do x
        arr = unwrap(x)
        return Dagger.move(from_proc, to_proc, arr)
    end
end
function Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CPUProc, x::CLArray{T,N}) where {T,N}
    with_context(from_proc) do
        cl.finish(cl.queue())
        _x = Array{T,N}(undef, size(x))
        copyto!(_x, x)
        cl.finish(cl.queue())
        return _x
    end
end

# Out-of-place DtoD
function Dagger.move(from_proc::CLArrayDeviceProc, to_proc::CLArrayDeviceProc, x::Dagger.Chunk{T}) where T<:CLArray
    if from_proc == to_proc
        # Same process and GPU, no change
        arr = unwrap(x)
        _sync_with_context(from_proc)
        return arr
    elseif Dagger.root_worker_id(from_proc) == Dagger.root_worker_id(to_proc)
        # Same process but different GPUs, use DtoD copy
        from_arr = unwrap(x)
        _sync_with_context(from_proc)
        return with_context(to_proc) do
            to_arr = similar(from_arr)
            copyto!(to_arr, from_arr)
            cl.finish(cl.queue())
            return to_arr
        end
    else
        # Different node, use DtoH, serialization, HtoD
        return CLArray(remotecall_fetch(from_proc.owner, x) do x
            Array(unwrap(x))
        end)
    end
end

# Adapt generic functions
Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::Function) = x
Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::Chunk{T}) where {T<:Function} =
    Dagger.move(from_proc, to_proc, fetch(x))

# Task execution
function Dagger.execute!(proc::CLArrayDeviceProc, f, args...; kwargs...)
    @nospecialize f args kwargs
    tls = Dagger.get_tls()
    task = Threads.@spawn begin
        Dagger.set_tls!(tls)
        with_context!(proc)
        result = Base.@invokelatest f(args...; kwargs...)
        # N.B. Synchronization must be done when accessing result or args
        return result
    end

    try
        fetch(task)
    catch err
        stk = current_exceptions(task)
        err, frames = stk[1]
        rethrow(CapturedException(err, frames))
    end
end

Dagger.gpu_processor(::Val{:OpenCL}) = CLArrayDeviceProc
Dagger.gpu_can_compute(::Val{:OpenCL}) = length(cl.platforms()) > 0
Dagger.gpu_kernel_backend(::CLArrayDeviceProc) = OpenCLBackend()
Dagger.gpu_with_device(f, proc::CLArrayDeviceProc) =
    cl.device!(f, proc.device)
function Dagger.gpu_synchronize(proc::CLArrayDeviceProc)
    with_context(proc) do
        cl.finish(QUEUES[proc.device])
    end
end
function Dagger.gpu_synchronize(::Val{:OpenCL})
    for idx in keys(DEVICES)
        _sync_with_context(CLArrayDeviceProc(myid(), idx))
    end
end

Dagger.to_scope(::Val{:cl_device}, sc::NamedTuple) =
    Dagger.to_scope(Val{:cl_devices}(), merge(sc, (;cl_devices=[sc.cl_device])))
Dagger.scope_key_precedence(::Val{:cl_device}) = 1
function Dagger.to_scope(::Val{:cl_devices}, sc::NamedTuple)
    if haskey(sc, :worker)
        workers = Int[sc.worker]
    elseif haskey(sc, :workers) && sc.workers != Colon()
        workers = sc.workers
    else
        workers = map(gproc->gproc.pid, Dagger.procs(Dagger.Sch.eager_context()))
    end
    scopes = Dagger.ExactScope[]
    dev_ids = sc.cl_devices
    for worker in workers
        procs = Dagger.get_processors(Dagger.OSProc(worker))
        for proc in procs
            proc isa CLArrayDeviceProc || continue
            if dev_ids == Colon() || proc.device in dev_ids
                scope = Dagger.ExactScope(proc)
                push!(scopes, scope)
            end
        end
    end
    return Dagger.UnionScope(scopes)
end
Dagger.scope_key_precedence(::Val{:cl_devices}) = 1

const DEVICES = Dict{Int, Device}()
const CONTEXTS = Dict{Int, Context}()
const QUEUES = Dict{Int, CmdQueue}()

function __init__()
    # FIXME: Support multiple platforms
    if length(cl.platforms()) > 0
        platform = cl.default_platform()
        for (idx, dev) in enumerate(cl.devices(platform))
            @debug "Registering OpenCL device processor with Dagger: $dev"
            Dagger.add_processor_callback!("clarray_device_$(idx)") do
                proc = CLArrayDeviceProc(myid(), idx)
                cl.device!(dev) do
                    DEVICES[idx] = dev
                    CONTEXTS[idx] = cl.context()
                    QUEUES[idx] = cl.queue()
                end
                return proc
            end
        end
    end
end

end # module OpenCLExt
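[Editorial note, not part of the commit: a minimal usage sketch, assuming a working OpenCL platform. The cl_device scope key, CLArrayDeviceProc, and the data-movement hooks come from the extension above; the specific workload is illustrative:

    using Dagger, OpenCL

    # Restrict scheduling to OpenCL device 1 on worker 1, via the
    # to_scope(::Val{:cl_device}, ...) hook defined in this extension
    scope = Dagger.scope(worker=1, cl_device=1)

    A = rand(Float32, 1024, 1024)          # host array
    t = Dagger.@spawn scope=scope sum(A)   # A is moved HtoD by Dagger.move
    fetch(t)
]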
