Skip to content

Commit a623002

Browse files
simeonschaub and maleadt committed
address review comments
Co-authored-by: Tim Besard <[email protected]>
1 parent 2951ddd commit a623002

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

lib/cl/kernel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ end
158158

159159
function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
160160
global_work_offset=nothing, wait_on::Vector{Event}=Event[],
161-
device_rng=false)
161+
rng_state=false)
162162
max_work_dim = device().max_work_item_dims
163163
work_dim = length(global_work_size)
164164
if work_dim > max_work_dim
@@ -201,7 +201,7 @@ function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
201201
# null local size means OpenCL decides
202202
end
203203

204-
if device_rng
204+
if rng_state
205205
if local_work_size !== nothing
206206
num_sub_groups = KernelSubGroupInfo(k, device(), lsize).sub_group_count
207207
else
@@ -249,7 +249,7 @@ function call(
249249
k::Kernel, args...; global_size = (1,), local_size = nothing,
250250
global_work_offset = nothing, wait_on::Vector{Event} = Event[],
251251
indirect_memory::Vector{AbstractMemory} = AbstractMemory[],
252-
device_rng=false,
252+
rng_state=false,
253253
)
254254
set_args!(k, args...)
255255
if !isempty(indirect_memory)
@@ -303,7 +303,7 @@ function call(
303303
clSetKernelExecInfo(k, CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL, sizeof(usm_pointers), usm_pointers)
304304
end
305305
end
306-
enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on, device_rng)
306+
enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on, rng_state)
307307
end
308308

309309
# From `julia/base/reflection.jl`, adjusted to add specialization on `t`.

src/compiler/compilation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
2828

2929
# if this kernel uses our RNG, we should prime the shared state.
3030
# XXX: these transformations should really happen at the Julia IR level...
31-
if haskey(functions(mod), "julia.spirv.random_keys") && job.config.kernel
31+
if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
3232
# insert call to `initialize_rng_state`
3333
f = initialize_rng_state
3434
ft = typeof(f)
@@ -88,8 +88,8 @@ end
8888
function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
8989
for f in GPUCompiler.kernels(mod)
9090
kernel_intrinsics = Dict(
91-
"julia.spirv.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
92-
"julia.spirv.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
91+
"julia.opencl.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
92+
"julia.opencl.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
9393
)
9494
GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
9595
end

src/compiler/execution.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ abstract type AbstractKernel{F, TT} end
159159

160160
quote
161161
indirect_memory = cl.AbstractMemory[]
162-
device_rng = kernel.fun.num_args == $(length(call_args) + 2)
163-
clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, device_rng, call_kwargs...)
162+
clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, kernel.rng_state, call_kwargs...)
164163
end
165164
end
166165

@@ -171,6 +170,7 @@ end
171170
struct HostKernel{F,TT} <: AbstractKernel{F,TT}
172171
f::F
173172
fun::cl.Kernel
173+
rng_state::Bool
174174
end
175175

176176

@@ -194,8 +194,10 @@ function clfunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
194194
h = hash(fun, hash(f, hash(tt)))
195195
kernel = get(_kernel_instances, h, nothing)
196196
if kernel === nothing
197+
# TODO: move the `rng_state` check into `OpenCL.compile` so we avoid the API call?
198+
rng_state = fun.num_args == length(tt.parameters) + 2
197199
# create the kernel state object
198-
kernel = HostKernel{F,tt}(f, fun)
200+
kernel = HostKernel{F,tt}(f, fun, rng_state)
199201
_kernel_instances[h] = kernel
200202
end
201203
return kernel::HostKernel{F,tt}

src/device/runtime.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,19 @@ end
2121

2222
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)
2323

24+
## intrinsics for adding and accessing additional kernel arguments
25+
26+
# The amount of local shared memory we need for storing RNG state is determined
27+
# dynamically at kernel launch time, so that memory needs to be passed to the kernel
28+
# as additional arguments.
29+
# We define intrinsics that get transformed into additional kernel arguments which
30+
# then get propagated across function calls to the caller.
31+
2432
function additional_arg_intr(mod::LLVM.Module, T_state, name)
25-
state_intr = if haskey(functions(mod), "julia.spirv.$name")
26-
functions(mod)["julia.spirv.$name"]
33+
state_intr = if haskey(functions(mod), "julia.opencl.$name")
34+
functions(mod)["julia.opencl.$name"]
2735
else
28-
LLVM.Function(mod, "julia.spirv.$name", LLVM.FunctionType(T_state))
36+
LLVM.Function(mod, "julia.opencl.$name", LLVM.FunctionType(T_state))
2937
end
3038
push!(function_attributes(state_intr), EnumAttribute("readnone", 0))
3139

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
44
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
55
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
6-
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
76
IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
87
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
98
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"

0 commit comments

Comments
 (0)