diff --git a/src/interface.jl b/src/interface.jl index 157f1977..69998018 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -265,6 +265,7 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented") # argument to each kernel, and pass that object to every function that accesses the kernel # state (possibly indirectly) via the `kernel_state_pointer` function. kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing +additional_arg_types(@nospecialize(job::CompilerJob)) = (;) # Does the target need to pass kernel arguments by value? pass_by_value(@nospecialize(job::CompilerJob)) = true diff --git a/src/irgen.jl b/src/irgen.jl index 63c5057d..778a169c 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -528,6 +528,14 @@ function add_kernel_state!(mod::LLVM.Module) state_intr = kernel_state_intr(mod, T_state) state_intr_ft = LLVM.FunctionType(T_state) + # additional arguments to pass to every function, but only if they are required + additional_args = haskey(functions(mod), "julia.gpu.additional_arg_getter") ? additional_arg_types(job) : (;) + T_additional_args = LLVMType[convert(LLVMType, T) for T in values(additional_args)] + names_additional_args = String[String(name) for name in keys(additional_args)] + + additional_arg_intrs = [additional_arg_intr(mod, T) for T in T_additional_args] + additional_arg_intr_fts = [LLVM.FunctionType(T, [convert(LLVMType, Int)]) for T in T_additional_args] + kernels = [] kernels_md = metadata(mod)["julia.kernel"] for kernel_md in operands(kernels_md) @@ -539,7 +547,7 @@ function add_kernel_state!(mod::LLVM.Module) # previously, we add the argument to every function and relied on unused arg elim to # clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting # function pointers. such IR is hard to rewrite, so instead be more conservative. - worklist = Set{LLVM.Function}([state_intr, kernels...]) + worklist = Set{LLVM.Function}([state_intr, additional_arg_intrs..., kernels...]) worklist_length = 0 while worklist_length != length(worklist) # iteratively discover functions that use the intrinsic or any function calling it @@ -567,6 +575,9 @@ function add_kernel_state!(mod::LLVM.Module) end end delete!(worklist, state_intr) + for intr in additional_arg_intrs + delete!(worklist, intr) + end # add a state argument workmap = Dict{LLVM.Function, LLVM.Function}() @@ -576,7 +587,7 @@ function add_kernel_state!(mod::LLVM.Module) LLVM.name!(f, fn * ".stateless") # create a new function - new_param_types = [T_state, parameters(ft)...] + new_param_types = [T_state, parameters(ft)..., T_additional_args...] new_ft = LLVM.FunctionType(return_type(ft), new_param_types) new_f = LLVM.Function(mod, fn, new_ft) LLVM.name!(parameters(new_f)[1], "state") @@ -584,6 +595,9 @@ function add_kernel_state!(mod::LLVM.Module) for (arg, new_arg) in zip(parameters(f), parameters(new_f)[2:end]) LLVM.name!(new_arg, LLVM.name(arg)) end + for (name, new_arg) in zip(names_additional_args, parameters(new_f)[(2 + length(parameters(ft))):end]) + LLVM.name!(new_arg, name) + end workmap[f] = new_f end @@ -609,7 +623,7 @@ function add_kernel_state!(mod::LLVM.Module) # is all this even sound? typ = value_type(val)::LLVM.PointerType ft = eltype(typ)::LLVM.FunctionType - new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)...]) + new_ft = LLVM.FunctionType(return_type(ft), [T_state, parameters(ft)..., T_additional_args...]) return const_bitcast(workmap[target], LLVM.PointerType(new_ft, addrspace(typ))) end elseif opcode(val) == LLVM.API.LLVMPtrToInt @@ -668,8 +682,12 @@ function add_kernel_state!(mod::LLVM.Module) # forward the state argument position!(builder, val) state = call!(builder, state_intr_ft, state_intr, Value[], "state") + additional_args = Value[ + call!(builder, additional_arg_intr_fts[i], additional_arg_intrs[i], Value[ConstantInt(i)], names_additional_args[i]) + for i in 1:length(additional_arg_intrs) + ] new_val = if val isa LLVM.CallInst - call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) + call!(builder, ft, f, [state, arguments(val)..., additional_args...], operand_bundles(val)) else # TODO: invoke and callbr error("Rewrite of $(typeof(val))-based calls is not implemented: $val") @@ -768,6 +786,39 @@ function lower_kernel_state!(fun::LLVM.Function) end end + additional_arg_tys = additional_arg_types(job) + if haskey(functions(mod), "julia.gpu.additional_arg_getter") + additional_arg_intr = functions(mod)["julia.gpu.additional_arg_getter"] + additional_args = Union{Value, Nothing}[nothing for i in 1:length(additional_arg_tys)] # only look-up when needed + + @dispose builder=IRBuilder() begin + for use in uses(additional_arg_intr) + inst = user(use) + @assert inst isa LLVM.CallInst + bb = LLVM.parent(inst) + LLVM.parent(bb) == fun || continue + + position!(builder, inst) + bb = LLVM.parent(inst) + f = LLVM.parent(bb) + + i = Int(convert(Int, operands(inst)[1]::ConstantInt)) + if additional_args[i] === nothing + additional_args[i] = parameters(fun)[end - length(additional_arg_tys) + i] + T_arg = convert(LLVMType, additional_arg_tys[i]) + @assert value_type(additional_args[i]) == T_arg + end + + replace_uses!(inst, additional_args[i]) + + @assert isempty(uses(inst)) + erase!(inst) + + changed = true + end + end + end + return changed end LowerKernelStatePass() = NewPMFunctionPass("LowerKernelStatePass", lower_kernel_state!) @@ -786,6 +837,14 @@ function cleanup_kernel_state!(mod::LLVM.Module) end end + if haskey(functions(mod), "julia.gpu.additional_arg_getter") + intr = functions(mod)["julia.gpu.additional_arg_getter"] + if isempty(uses(intr)) + erase!(intr) + changed = true + end + end + return changed end CleanupKernelStatePass() = NewPMModulePass("CleanupKernelStatePass", cleanup_kernel_state!) @@ -923,3 +982,41 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M return new_f end end + +function additional_arg_intr(mod::LLVM.Module, T_arg) + additional_arg_intr = if haskey(functions(mod), "julia.gpu.additional_arg_getter") + functions(mod)["julia.gpu.additional_arg_getter"] + else + LLVM.Function(mod, "julia.gpu.additional_arg_getter", LLVM.FunctionType(T_arg, [convert(LLVMType, Int)])) + end + push!(function_attributes(additional_arg_intr), EnumAttribute("readnone", 0)) + + return additional_arg_intr +end + +# run-time equivalent +function additional_arg_value(arg, index::Int) + @dispose ctx=Context() begin + T_arg = convert(LLVMType, arg) + + # create function + llvm_f, _ = create_function(T_arg) + mod = LLVM.parent(llvm_f) + + # get intrinsic + _additional_arg_intr = additional_arg_intr(mod, T_arg) + additional_arg_intr_ft = function_type(_additional_arg_intr) + + # generate IR + @dispose builder=IRBuilder() begin + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + + val = call!(builder, additional_arg_intr_ft, _additional_arg_intr, Value[ConstantInt(index)], "additional_arg") + + ret!(builder, val) + end + + call_function(llvm_f, arg) + end +end