Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented")
# argument to each kernel, and pass that object to every function that accesses the kernel
# state (possibly indirectly) via the `kernel_state_pointer` function.
# The fallback for any job is `Nothing`.
function kernel_state_type(@nospecialize(job::CompilerJob))
    return Nothing
end
# Extra arguments appended to the signature of rewritten functions, given as a NamedTuple
# mapping argument name to argument type. The fallback is empty, i.e. no extra arguments.
additional_arg_types(@nospecialize(job::CompilerJob)) = NamedTuple()

# Does the target need to pass kernel arguments by value?
# The generic fallback answers `true`.
function pass_by_value(@nospecialize(job::CompilerJob))
    return true
end
Expand Down
105 changes: 101 additions & 4 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,14 @@ function add_kernel_state!(mod::LLVM.Module)
state_intr = kernel_state_intr(mod, T_state)
state_intr_ft = LLVM.FunctionType(T_state)

# additional arguments to pass to every function, but only if they are required
additional_args = haskey(functions(mod), "julia.gpu.additional_arg_getter") ? additional_arg_types(job) : (;)
T_additional_args = LLVMType[convert(LLVMType, T) for T in values(additional_args)]
names_additional_args = String[String(name) for name in keys(additional_args)]

additional_arg_intrs = [additional_arg_intr(mod, T) for T in T_additional_args]
additional_arg_intr_fts = [LLVM.FunctionType(T, [convert(LLVMType, Int)]) for T in T_additional_args]

kernels = []
kernels_md = metadata(mod)["julia.kernel"]
for kernel_md in operands(kernels_md)
Expand All @@ -539,7 +547,7 @@ function add_kernel_state!(mod::LLVM.Module)
# previously, we add the argument to every function and relied on unused arg elim to
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
# function pointers. such IR is hard to rewrite, so instead be more conservative.
worklist = Set{LLVM.Function}([state_intr, kernels...])
worklist = Set{LLVM.Function}([state_intr, additional_arg_intrs..., kernels...])
worklist_length = 0
while worklist_length != length(worklist)
# iteratively discover functions that use the intrinsic or any function calling it
Expand Down Expand Up @@ -567,6 +575,9 @@ function add_kernel_state!(mod::LLVM.Module)
end
end
delete!(worklist, state_intr)
for intr in additional_arg_intrs
delete!(worklist, intr)
end

# add a state argument
workmap = Dict{LLVM.Function, LLVM.Function}()
Expand All @@ -576,14 +587,17 @@ function add_kernel_state!(mod::LLVM.Module)
LLVM.name!(f, fn * ".stateless")

# create a new function
new_param_types = [T_state, parameters(ft)...]
new_param_types = [T_state, parameters(ft)..., T_additional_args...]
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
new_f = LLVM.Function(mod, fn, new_ft)
LLVM.name!(parameters(new_f)[1], "state")
linkage!(new_f, linkage(f))
for (arg, new_arg) in zip(parameters(f), parameters(new_f)[2:end])
LLVM.name!(new_arg, LLVM.name(arg))
end
for (name, new_arg) in zip(names_additional_args, parameters(new_f)[(2 + length(parameters(ft))):end])
LLVM.name!(new_arg, name)
end

workmap[f] = new_f
end
Expand All @@ -609,7 +623,7 @@ function add_kernel_state!(mod::LLVM.Module)
# is all this even sound?
typ = value_type(val)::LLVM.PointerType
ft = eltype(typ)::LLVM.FunctionType
new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)...])
new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)..., T_additional_args...])
return const_bitcast(workmap[target], LLVM.PointerType(new_ft, addrspace(typ)))
end
elseif opcode(val) == LLVM.API.LLVMPtrToInt
Expand Down Expand Up @@ -668,8 +682,12 @@ function add_kernel_state!(mod::LLVM.Module)
# forward the state argument
position!(builder, val)
state = call!(builder, state_intr_ft, state_intr, Value[], "state")
additional_args = Value[
call!(builder, additional_arg_intr_fts[i], additional_arg_intrs[i], Value[ConstantInt(i)], names_additional_args[i])
for i in 1:length(additional_arg_intrs)
]
new_val = if val isa LLVM.CallInst
call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val))
call!(builder, ft, f, [state, arguments(val)..., additional_args...], operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
Expand Down Expand Up @@ -768,6 +786,39 @@ function lower_kernel_state!(fun::LLVM.Function)
end
end

additional_arg_tys = additional_arg_types(job)
if haskey(functions(mod), "julia.gpu.additional_arg_getter")
additional_arg_intr = functions(mod)["julia.gpu.additional_arg_getter"]
additional_args = Union{Value, Nothing}[nothing for i in 1:length(additional_arg_tys)] # only look-up when needed

@dispose builder=IRBuilder() begin
for use in uses(additional_arg_intr)
inst = user(use)
@assert inst isa LLVM.CallInst
bb = LLVM.parent(inst)
LLVM.parent(bb) == fun || continue

position!(builder, inst)
bb = LLVM.parent(inst)
f = LLVM.parent(bb)

i = Int(convert(Int, operands(inst)[1]::ConstantInt))
if additional_args[i] === nothing
additional_args[i] = parameters(fun)[end - length(additional_arg_tys) + i]
T_arg = convert(LLVMType, additional_arg_tys[i])
@assert value_type(additional_args[i]) == T_arg
end

replace_uses!(inst, additional_args[i])

@assert isempty(uses(inst))
erase!(inst)

changed = true
end
end
end

return changed
end
# Pass constructor: wraps `lower_kernel_state!` as a function pass registered under the
# name "LowerKernelStatePass".
function LowerKernelStatePass()
    return NewPMFunctionPass("LowerKernelStatePass", lower_kernel_state!)
end
Expand All @@ -786,6 +837,14 @@ function cleanup_kernel_state!(mod::LLVM.Module)
end
end

if haskey(functions(mod), "julia.gpu.additional_arg_getter")
intr = functions(mod)["julia.gpu.additional_arg_getter"]
if isempty(uses(intr))
erase!(intr)
changed = true
end
end

return changed
end
# Pass constructor: wraps `cleanup_kernel_state!` as a module pass registered under the
# name "CleanupKernelStatePass".
function CleanupKernelStatePass()
    return NewPMModulePass("CleanupKernelStatePass", cleanup_kernel_state!)
end
Expand Down Expand Up @@ -923,3 +982,41 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
return new_f
end
end

# Look up the `julia.gpu.additional_arg_getter` intrinsic in `mod`, declaring it when the
# module does not contain it yet. A fresh declaration takes one integer parameter (the LLVM
# type corresponding to `Int`) and returns `T_arg`.
#
# NOTE(review): an existing declaration is returned as-is, without checking that its return
# type matches `T_arg` -- confirm this is intended when several argument types are in use.
function additional_arg_intr(mod::LLVM.Module, T_arg)
    intr_name = "julia.gpu.additional_arg_getter"
    fns = functions(mod)

    getter = if haskey(fns, intr_name)
        fns[intr_name]
    else
        T_index = convert(LLVMType, Int)
        LLVM.Function(mod, intr_name, LLVM.FunctionType(T_arg, [T_index]))
    end

    # the attribute is (re-)applied on every lookup, also for a pre-existing declaration
    push!(function_attributes(getter), EnumAttribute("readnone", 0))

    return getter
end

# run-time equivalent
#
# Emits a function whose body calls the `julia.gpu.additional_arg_getter` intrinsic with the
# constant `index` and returns the result, then hands that function to `call_function`
# together with `arg`.
#
# NOTE(review): `arg` is passed to `convert(LLVMType, ...)`, which suggests it is expected to
# be a type rather than a value -- confirm how this function is meant to be invoked.
function additional_arg_value(arg, index::Int)
    @dispose ctx=Context() begin
        T_ret = convert(LLVMType, arg)

        # wrapper function and the module it lives in
        wrapper, _ = create_function(T_ret)
        wrapper_mod = LLVM.parent(wrapper)

        # the intrinsic, plus its function type
        getter = additional_arg_intr(wrapper_mod, T_ret)
        getter_ft = function_type(getter)

        # body: return the result of `getter(index)`
        @dispose builder=IRBuilder() begin
            position!(builder, BasicBlock(wrapper, "entry"))
            idx = Value[ConstantInt(index)]
            ret!(builder, call!(builder, getter_ft, getter, idx, "additional_arg"))
        end

        call_function(wrapper, arg)
    end
end
Loading