12 changes: 8 additions & 4 deletions Project.toml
@@ -13,29 +13,33 @@ OpenCL_jll = "6cb37087-e8b6-5417-8430-1f242f1e46e4"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[sources]
SPIRVIntrinsics = {path = "lib/intrinsics"}

[compat]
Adapt = "4"
GPUArrays = "11.2.1"
GPUCompiler = "1.6"
GPUCompiler = "1.7.1"
KernelAbstractions = "0.9.2"
LLVM = "9.1"
LinearAlgebra = "1"
OpenCL_jll = "=2024.10.24"
Preferences = "1"
Printf = "1"
Random = "1"
Random123 = "1.7.1"
RandomNumbers = "1.6.0"
Reexport = "1"
SPIRVIntrinsics = "0.5"
SPIRV_LLVM_Backend_jll = "20"
SPIRV_Tools_jll = "2025.1"
StaticArrays = "1"
julia = "1.10"

[sources]
SPIRVIntrinsics = {path="lib/intrinsics"}
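Note on the new dependencies: Random123 supplies counter-based generators (Philox, Threefry) whose output is a pure function of a (key, counter) pair, which is what makes them practical on GPUs, where per-work-item state must be cheap to reconstruct. A minimal host-side illustration of that determinism (illustrative only; the device-side implementation lives in src/device/random.jl, not in these packages directly):

```julia
using Random123

# two generators built from the same seed yield identical streams
rng1 = Philox2x(42)
rng2 = Philox2x(42)
@assert rand(rng1, UInt32, 4) == rand(rng2, UInt32, 4)
```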
69 changes: 66 additions & 3 deletions lib/cl/kernel.jl
@@ -106,11 +106,59 @@ function Base.getproperty(ki::KernelWorkGroupInfo, s::Symbol)
    end
end

struct KernelSubGroupInfo
    kernel::Kernel
    device::Device
    local_work_size::Vector{Csize_t}
end
sub_group_info(k::Kernel, d::Device, l) = KernelSubGroupInfo(k, d, Vector{Csize_t}(l))

# helper for querying the local size that yields a specific sub-group count
function local_size_for_sub_group_count(ki::KernelSubGroupInfo, sub_group_count::Integer)
    k = getfield(ki, :kernel)
    d = getfield(ki, :device)
    input_value = Ref{Csize_t}(sub_group_count)
    result = Ref{NTuple{3, Csize_t}}()
    clGetKernelSubGroupInfo(k, d, CL_KERNEL_LOCAL_SIZE_FOR_SUB_GROUP_COUNT,
                            sizeof(Csize_t), input_value, sizeof(NTuple{3, Csize_t}), result, C_NULL)
    return Int.(result[])
end

function Base.getproperty(ki::KernelSubGroupInfo, s::Symbol)
    k = getfield(ki, :kernel)
    d = getfield(ki, :device)
    lws = getfield(ki, :local_work_size)

    function get(val, typ)
        result = Ref{typ}()
        clGetKernelSubGroupInfo(k, d, val, sizeof(lws), lws, sizeof(typ), result, C_NULL)
        return result[]
    end

    if s == :max_sub_group_size
        Int(get(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, Csize_t))
    elseif s == :sub_group_count
        Int(get(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE, Csize_t))
    elseif s == :local_size_for_sub_group_count
        # this query takes the desired sub-group count as an input value
        error("local_size_for_sub_group_count requires specifying the desired sub-group count; use the local_size_for_sub_group_count helper function instead")
    elseif s == :max_num_sub_groups
        Int(get(CL_KERNEL_MAX_NUM_SUB_GROUPS, Csize_t))
    elseif s == :compile_num_sub_groups
        Int(get(CL_KERNEL_COMPILE_NUM_SUB_GROUPS, Csize_t))
    elseif s == :compile_sub_group_size
        Int(get(CL_KERNEL_COMPILE_SUB_GROUP_SIZE_INTEL, Csize_t))
    else
        getfield(ki, s)
    end
end

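A hypothetical query against a compiled kernel, to make the new API concrete (`kern` and the local size are illustrative):

```julia
info = sub_group_info(kern, device(), [64, 1, 1])
info.max_sub_group_size  # widest sub-group the implementation would use here
info.sub_group_count     # sub-groups per work-group at this local size

# the inverse query takes the desired count as an input value,
# hence the separate helper function rather than a property:
local_size_for_sub_group_count(info, 4)
```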

## kernel calling

function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
                        global_work_offset=nothing, wait_on::Vector{Event}=Event[])
                        global_work_offset=nothing, wait_on::Vector{Event}=Event[],
                        rng_state=false, nargs=nothing)
    max_work_dim = device().max_work_item_dims
    work_dim = length(global_work_size)
    if work_dim > max_work_dim
@@ -153,6 +201,20 @@ function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
        # null local size means OpenCL decides
    end

    if rng_state
        if local_work_size !== nothing
            num_sub_groups = KernelSubGroupInfo(k, device(), lsize).sub_group_count
        else
            num_sub_groups = KernelSubGroupInfo(k, device(), Csize_t[]).max_num_sub_groups
        end
        if nargs === nothing
            nargs = k.num_args - 2
        end
        rng_state_size = sizeof(UInt32) * num_sub_groups
        set_arg!(k, nargs + 1, LocalMem(UInt32, rng_state_size))
        set_arg!(k, nargs + 2, LocalMem(UInt32, rng_state_size))
    end

    if !isempty(wait_on)
        n_events = length(wait_on)
        wait_event_ids = [evt.id for evt in wait_on]
@@ -189,7 +251,8 @@ end
function call(
        k::Kernel, args...; global_size = (1,), local_size = nothing,
        global_work_offset = nothing, wait_on::Vector{Event} = Event[],
        indirect_memory::Vector{AbstractMemory} = AbstractMemory[]
        indirect_memory::Vector{AbstractMemory} = AbstractMemory[],
        rng_state=false,
    )
    set_args!(k, args...)
    if !isempty(indirect_memory)
Expand Down Expand Up @@ -243,7 +306,7 @@ function call(
            clSetKernelExecInfo(k, CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL, sizeof(usm_pointers), usm_pointers)
        end
    end
    enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on)
    enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on, rng_state, nargs=length(args))
end

# From `julia/base/reflection.jl`, adjusted to add specialization on `t`.
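To summarize the launch-side protocol: with `rng_state` set, two workgroup-local buffers (keys and counters, one `UInt32` slot per sub-group) are silently appended after the user-visible arguments, which is why `call` forwards `nargs=length(args)`. The sizing arithmetic, under assumed device characteristics:

```julia
# e.g. a 256-wide work-group on a device with 32-wide sub-groups
num_sub_groups = cld(256, 32)                     # 8
rng_state_size = sizeof(UInt32) * num_sub_groups  # 32

# what enqueue_kernel then does, with nargs user arguments already set:
# set_arg!(k, nargs + 1, LocalMem(UInt32, rng_state_size))  # keys
# set_arg!(k, nargs + 2, LocalMem(UInt32, rng_state_size))  # counters
```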
1 change: 1 addition & 0 deletions src/OpenCL.jl
@@ -24,6 +24,7 @@ Base.Experimental.@MethodTable(method_table)
include("device/runtime.jl")
include("device/array.jl")
include("device/quirks.jl")
include("device/random.jl")

# high level implementation
include("memory.jl")
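`device/random.jl` itself is collapsed in this view. Given the Random123 dependency and the key/counter buffers the compiler threads through (see src/compiler/compilation.jl below), a plausible shape for the device-side generator is sketched here; apart from the `random_keys`/`random_counters` accessors, every name is hypothetical:

```julia
# hypothetical sketch, not the actual contents of device/random.jl
@inline function rand_uint32()
    keys     = random_keys()      # workgroup-local pointer, injected by the compiler
    counters = random_counters()
    sg = get_sub_group_id()       # assumed sub-group intrinsic wrapper
    key = unsafe_load(keys, sg + 1)
    ctr = unsafe_load(counters, sg + 1)
    val = philox_mix(key, ctr)    # hypothetical Philox-style mixing rounds
    unsafe_store!(counters, ctr + one(UInt32), sg + 1)
    return val
end
```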
94 changes: 89 additions & 5 deletions src/compiler/compilation.jl
@@ -18,6 +18,86 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
    in(fn, known_intrinsics) ||
    contains(fn, "__spirv_")

GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState

function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
                                    mod::LLVM.Module, entry::LLVM.Function)
    entry = invoke(GPUCompiler.finish_module!,
                   Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
                   job, mod, entry)

    # if this kernel uses our RNG, we should prime the shared state.
    # XXX: these transformations should really happen at the Julia IR level...
    if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
        # insert call to `initialize_rng_state`
        f = initialize_rng_state
        ft = typeof(f)
        tt = Tuple{}

        # create a deferred compilation job for `initialize_rng_state`
        src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
        cfg = CompilerConfig(job.config; kernel=false, name=nothing)
        job = CompilerJob(src, cfg, job.world)
        id = length(GPUCompiler.deferred_codegen_jobs) + 1
        GPUCompiler.deferred_codegen_jobs[id] = job

        # generate IR for calls to `deferred_codegen` and the resulting function pointer
        top_bb = first(blocks(entry))
        bb = BasicBlock(top_bb, "initialize_rng")
        @dispose builder=IRBuilder() begin
            position!(builder, bb)
            subprogram = LLVM.subprogram(entry)
            if subprogram !== nothing
                loc = DILocation(0, 0, subprogram)
                debuglocation!(builder, loc)
            end
            debuglocation!(builder, first(instructions(top_bb)))

            # call the `deferred_codegen` marker function
            T_ptr = if LLVM.version() >= v"17"
                LLVM.PointerType()
            elseif VERSION >= v"1.12.0-DEV.225"
                LLVM.PointerType(LLVM.Int8Type())
            else
                LLVM.Int64Type()
            end
            T_id = convert(LLVMType, Int)
            deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
            deferred_codegen = if haskey(functions(mod), "deferred_codegen")
                functions(mod)["deferred_codegen"]
            else
                LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
            end
            fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

            # call the `initialize_rng_state` function
            rt = Core.Compiler.return_type(f, tt)
            llvm_rt = convert(LLVMType, rt)
            llvm_ft = LLVM.FunctionType(llvm_rt)
            fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
            call!(builder, llvm_ft, fptr)
            br!(builder, top_bb)

            # note the use of the device-side RNG in this kernel
            push!(function_attributes(entry), StringAttribute("julia.opencl.rng", ""))
        end

        # XXX: put some of the above behind GPUCompiler abstractions
        # (e.g., a compile-time version of `deferred_codegen`)
    end
    return entry
end

function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
    for f in GPUCompiler.kernels(mod)
        kernel_intrinsics = Dict(
            "julia.opencl.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
            "julia.opencl.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
        )
        GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
    end
    return
end

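The two `finish_*` hooks cooperate: `finish_module!` prepends a call to `initialize_rng_state` (compiled separately through the deferred-codegen marker) to any kernel that references the RNG placeholder intrinsics, and `finish_linked_module!` then rewrites those placeholders into real kernel parameters. Conceptually (signatures shown for illustration, not generated source):

```julia
# before: device code calls argument-less placeholder intrinsics
#     julia.opencl.random_keys() / julia.opencl.random_counters()
# after add_input_arguments!, each kernel gains two trailing parameters:
#     kernel(args..., random_keys::LLVMPtr{UInt32, AS.Workgroup},
#                     random_counters::LLVMPtr{UInt32, AS.Workgroup})
# these are what the two LocalMem arguments in enqueue_kernel bind to.
```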
## compiler implementation (cache, configure, compile, and link)

@@ -57,13 +137,17 @@ end
# compile to executable machine code
const compilations = Threads.Atomic{Int}(0)
function compile(@nospecialize(job::CompilerJob))
    compilations[] += 1

    # TODO: this creates a context; cache those.
    obj, meta = JuliaContext() do ctx
        GPUCompiler.compile(:obj, job)
    end
    compilations[] += 1
    obj, meta = GPUCompiler.compile(:obj, job)

    (obj, entry=LLVM.name(meta.entry))
    entry = LLVM.name(meta.entry)
    device_rng = StringAttribute("julia.opencl.rng", "") in collect(function_attributes(meta.entry))

    (; obj, entry, device_rng)
end
end

# link into an executable kernel
@@ -82,5 +166,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
        cl.Program(; source)
    end
    cl.build!(prog)
    cl.Kernel(prog, compiled.entry)
    (; kernel=cl.Kernel(prog, compiled.entry), compiled.device_rng)
end
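The compile/link cache entry now carries the RNG flag alongside the kernel object, so consumers destructure it rather than receiving a bare `cl.Kernel`; roughly:

```julia
linked = GPUCompiler.cached_compilation(cache, source, config, compile, link)
linked.kernel      # cl.Kernel
linked.device_rng  # Bool: was the kernel tagged with "julia.opencl.rng"?
```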
17 changes: 11 additions & 6 deletions src/compiler/execution.jl
@@ -132,14 +132,15 @@ kernel_convert(arg, indirect_memory::Vector{cl.AbstractMemory} = cl.AbstractMemo

abstract type AbstractKernel{F, TT} end

pass_arg(@nospecialize dt) = !(isghosttype(dt) || Core.Compiler.isconstType(dt))

@inline @generated function (kernel::AbstractKernel{F,TT})(args...;
                                                           call_kwargs...) where {F,TT}
    sig = Tuple{F, TT.parameters...} # Base.signature_type with a function type
    args = (:(kernel.f), (:(kernel_convert(args[$i], indirect_memory)) for i in 1:length(args))...)

    # filter out ghost arguments that shouldn't be passed
    predicate = dt -> isghosttype(dt) || Core.Compiler.isconstType(dt)
    to_pass = map(!predicate, sig.parameters)
    to_pass = map(pass_arg, sig.parameters)
    call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
    call_args = Union{Expr,Symbol}[x[1] for x in zip(args, to_pass) if x[2]]

@@ -151,12 +152,15 @@ abstract type AbstractKernel{F, TT} end
        end
    end

    pushfirst!(call_t, KernelState)
    pushfirst!(call_args, :(KernelState(kernel.rng_state ? Base.rand(UInt32) : UInt32(0))))

    # finalize types
    call_tt = Base.to_tuple_type(call_t)

    quote
        indirect_memory = cl.AbstractMemory[]
        clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, call_kwargs...)
        clcall(kernel.fun, $call_tt, $(call_args...); indirect_memory, kernel.rng_state, call_kwargs...)
    end
end

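To make the hidden state argument concrete: for a kernel taking a single converted argument `conv_a`, the generated body expands to roughly the following (illustrative; names and types elided):

```julia
indirect_memory = cl.AbstractMemory[]
state = KernelState(kernel.rng_state ? Base.rand(UInt32) : UInt32(0))  # fresh seed per launch
clcall(kernel.fun, Tuple{KernelState, typeof(conv_a)}, state, conv_a;
       indirect_memory, kernel.rng_state, call_kwargs...)
```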
@@ -167,6 +171,7 @@ end
struct HostKernel{F,TT} <: AbstractKernel{F,TT}
    f::F
    fun::cl.Kernel
    rng_state::Bool
end


@@ -183,15 +188,15 @@ function clfunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
    cache = compiler_cache(ctx)
    source = methodinstance(F, tt)
    config = compiler_config(dev; kwargs...)::OpenCLCompilerConfig
    fun = GPUCompiler.cached_compilation(cache, source, config, compile, link)
    linked = GPUCompiler.cached_compilation(cache, source, config, compile, link)

    # create a callable object that captures the function instance. we don't need to think
    # about world age here, as GPUCompiler already does and will return a different object
    h = hash(fun, hash(f, hash(tt)))
    h = hash(linked.kernel, hash(f, hash(tt)))
    kernel = get(_kernel_instances, h, nothing)
    if kernel === nothing
        # create the kernel state object
        kernel = HostKernel{F,tt}(f, fun)
        kernel = HostKernel{F,tt}(f, linked.kernel, linked.device_rng)
        _kernel_instances[h] = kernel
    end
    return kernel::HostKernel{F,tt}
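End to end, the feature stays invisible at the call site: kernels that use the device RNG get tagged during codegen, `clfunction` records the flag on the `HostKernel`, and each launch then seeds a fresh `KernelState` and appends the local buffers. A hypothetical user-level example (assumes a device-side `rand` overload from device/random.jl, and uses the package's existing `@opencl` launch macro):

```julia
using OpenCL

function fill_random(a)
    i = get_global_id(1)
    if i <= length(a)
        @inbounds a[i] = rand(Float32)  # assumed device-side overload
    end
    return
end

a = CLArray{Float32}(undef, 1024)
@opencl global_size=length(a) fill_random(a)
```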