@@ -530,8 +530,8 @@ function add_kernel_state!(mod::LLVM.Module)
530530
531531 # additional arguments to pass to every function, but only if they are required
532532 additional_args = haskey (functions (mod), " julia.gpu.additional_arg_getter" ) ? additional_arg_types (job) : (;)
533- T_additional_args = convert . (LLVMType, values (additional_args))
534- names_additional_args = String .( keys (additional_args))
533+ T_additional_args = LLVMType[ convert (LLVMType, T) for T in values (additional_args)]
534+ names_additional_args = String[ String (name) for name in keys (additional_args)]
535535
536536 additional_arg_intrs = additional_arg_intr .(Ref (mod), T_additional_args)
537537 additional_arg_intr_fts = LLVM. FunctionType .(T_additional_args)
@@ -804,13 +804,12 @@ function lower_kernel_state!(fun::LLVM.Function)
804804
805805 i = Int (convert (Int, operands (inst)[1 ]:: ConstantInt ))
806806 if additional_args[i] === nothing
807- state_arg = parameters (fun)[end - length (additional_arg_tys) + i]
808- T_state = convert (LLVMType, additional_arg_tys[i])
809- @assert value_type (state_arg) == T_state
810- additional_args[i] = state_arg
807+ additional_args[i] = parameters (fun)[end - length (additional_arg_tys) + i]
808+ T_arg = convert (LLVMType, additional_arg_tys[i])
809+ @assert value_type (additional_args[i]) == T_arg
811810 end
812811
813- replace_uses! (inst, state_arg )
812+ replace_uses! (inst, additional_args[i] )
814813
815814 @assert isempty (uses (inst))
816815 erase! (inst)
@@ -841,7 +840,6 @@ function cleanup_kernel_state!(mod::LLVM.Module)
841840 if haskey (functions (mod), " julia.gpu.additional_arg_getter" )
842841 intr = functions (mod)[" julia.gpu.additional_arg_getter" ]
843842 if isempty (uses (intr))
844- # if we're not emitting a kernel, we can't resolve the intrinsic to an argument.
845843 erase! (intr)
846844 changed = true
847845 end
@@ -985,40 +983,40 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
985983 end
986984end
987985
988- function additional_arg_intr (mod:: LLVM.Module , T_state )
989- state_intr = if haskey (functions (mod), " julia.gpu.additional_arg_getter" )
986+ function additional_arg_intr (mod:: LLVM.Module , T_arg )
987+ additional_arg_intr = if haskey (functions (mod), " julia.gpu.additional_arg_getter" )
990988 functions (mod)[" julia.gpu.additional_arg_getter" ]
991989 else
992- LLVM. Function (mod, " julia.gpu.additional_arg_getter" , LLVM. FunctionType (T_state ))
990+ LLVM. Function (mod, " julia.gpu.additional_arg_getter" , LLVM. FunctionType (T_arg ))
993991 end
994- push! (function_attributes (state_intr ), EnumAttribute (" readnone" , 0 ))
992+ push! (function_attributes (additional_arg_intr ), EnumAttribute (" readnone" , 0 ))
995993
996- return state_intr
994+ return additional_arg_intr
997995end
998996
999997# run-time equivalent
1000- function additional_arg_value (state , index)
998+ function additional_arg_value (arg , index)
1001999 @dispose ctx= Context () begin
1002- T_state = convert (LLVMType, state )
1000+ T_arg = convert (LLVMType, arg )
10031001
10041002 # create function
1005- llvm_f, _ = create_function (T_state )
1003+ llvm_f, _ = create_function (T_arg )
10061004 mod = LLVM. parent (llvm_f)
10071005
10081006 # get intrinsic
1009- state_intr = additional_arg_intr (mod, T_state )
1010- state_intr_ft = function_type (state_intr )
1007+ _additional_arg_intr = additional_arg_intr (mod, T_arg )
1008+ additional_arg_intr_ft = function_type (_additional_arg_intr )
10111009
10121010 # generate IR
10131011 @dispose builder= IRBuilder () begin
10141012 entry = BasicBlock (llvm_f, " entry" )
10151013 position! (builder, entry)
10161014
1017- val = call! (builder, state_intr_ft, state_intr , Value[ConstantInt (index)], " additional_arg" )
1015+ val = call! (builder, additional_arg_intr_ft, _additional_arg_intr , Value[ConstantInt (index)], " additional_arg" )
10181016
10191017 ret! (builder, val)
10201018 end
10211019
1022- call_function (llvm_f, state )
1020+ call_function (llvm_f, arg )
10231021 end
10241022end
0 commit comments