@@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob))
8282
8383 # minimal required optimization
8484 @tracepoint " rewrite" begin
85- if job. config. kernel && needs_byval (job)
85+ if job. config. kernel && pass_by_value (job)
8686 # pass all bitstypes by value; by default Julia passes aggregates by reference
8787 # (this improves performance, and is mandated by certain back-ends like SPIR-V).
8888 args = classify_arguments (job, function_type (entry))
@@ -256,10 +256,11 @@ end
256256# # kernel promotion
257257
# Calling convention of a single kernel argument, as reported by `classify_arguments`.
@enum ArgumentCC begin
    # bitstype, passed as value
    BITS_VALUE
    # bitstype, passed as pointer
    BITS_REF
    # jl_value_t*, or the anonymous equivalent
    MUT_REF
    # not passed
    GHOST
    # the kernel state argument
    KERNEL_STATE
end
264265
265266# Determine the calling convention of the arguments of a Julia function, given the
270271# - `name`: the name of the argument
271272# - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
272273# is not passed at the LLVM level.
273- function classify_arguments (@nospecialize (job:: CompilerJob ), codegen_ft:: LLVM.FunctionType )
274+ function classify_arguments (@nospecialize (job:: CompilerJob ), codegen_ft:: LLVM.FunctionType ;
275+ post_optimization:: Bool = false )
274276 source_sig = job. source. specTypes
275277 source_types = [source_sig. parameters... ]
276278
@@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
282284
283285 codegen_types = parameters (codegen_ft)
284286
285- args = []
286- codegen_i = 1
287- for (source_i, (source_typ, source_name)) in enumerate (zip (source_types, source_argnames))
287+ if post_optimization && kernel_state_type (job) != = Nothing
288+ args = []
289+ push! (args, (cc= KERNEL_STATE, typ= kernel_state_type (job), name= :kernel_state , idx= 1 ))
290+ codegen_i = 2
291+ else
292+ args = []
293+ codegen_i = 1
294+ end
295+ for (source_typ, source_name) in zip (source_types, source_argnames)
288296 if isghosttype (source_typ) || Core. Compiler. isconstType (source_typ)
289297 push! (args, (cc= GHOST, typ= source_typ, name= source_name, idx= nothing ))
290298 continue
@@ -817,3 +825,99 @@ function kernel_state_value(state)
817825 call_function (llvm_f, state)
818826 end
819827end
828+
# Rewrite the kernel state argument of `f` from pass-by-value to pass-by-reference.
#
# The kernel state is always emitted as a by-value argument to avoid codegen issues with
# `byval`. Some back-ends cannot take kernel arguments by value, so this pass converts
# that one argument into a pointer (conceptually the inverse of `lower_byval`).
#
# Returns `f` untouched when the job has no kernel state (`kernel_state_type(job)` is
# `Nothing`) or when the first parameter of `f` is not of the kernel state type.
# Otherwise `f` is erased, and the replacement function created in `mod` (which takes
# over the name and linkage of `f`) is returned.
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    f::LLVM.Function)
    # jobs without a kernel state have nothing to convert
    state = kernel_state_type(job)
    state === Nothing && return f
    T_state = convert(LLVMType, state)

    # the kernel state is expected to be the very first parameter; bail out otherwise
    ft = function_type(f)
    old_types = parameters(ft)
    if isempty(old_types) || value_type(parameters(f)[1]) != T_state
        return f
    end

    @tracepoint "kernel state to reference" begin
        # new signature: a pointer to the state, followed by the untouched remainder
        new_types = LLVM.LLVMType[LLVM.PointerType(T_state)]
        append!(new_types, old_types[2:end])

        new_ft = LLVM.FunctionType(return_type(ft), new_types)
        new_f = LLVM.Function(mod, "", new_ft)
        linkage!(new_f, linkage(f))

        # carry over the parameter names, except for the replaced state argument
        LLVM.name!(parameters(new_f)[1], "state_ptr")
        for (old_param, new_param) in zip(parameters(f)[2:end], parameters(new_f)[2:end])
            LLVM.name!(new_param, LLVM.name(old_param))
        end

        # emit a leading block that turns the new arguments into the values the
        # original body expects
        new_args = LLVM.Value[]
        @dispose irb=IRBuilder() begin
            conversion_bb = BasicBlock(new_f, "conversion")
            position!(irb, conversion_bb)

            # the state itself is loaded through the pointer ...
            push!(new_args, load!(irb, T_state, parameters(new_f)[1], "state"))
            # ... while every other argument is forwarded as-is
            append!(new_args, parameters(new_f)[2:end])

            # old parameter => replacement value, plus the function itself
            value_map = Dict{LLVM.Value, LLVM.Value}(
                old_param => new_arg for (old_param, new_arg) in zip(parameters(f), new_args)
            )
            value_map[f] = new_f

            clone_into!(new_f, f; value_map,
                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

            # fall through from the conversion block into the first cloned block
            br!(irb, blocks(new_f)[2])
        end

        # attributes of the state pointer parameter:
        # - nocapture: the pointer is only loaded from, immediately
        # - noalias:   each kernel state is separate
        # - readonly:  the state is never written through the pointer
        state_attrs = parameter_attributes(new_f, 1)
        for attr_name in ("nocapture", "noalias", "readonly")
            push!(state_attrs, EnumAttribute(attr_name, 0))
        end

        # drop the old function and let the new one take over its name
        old_name = LLVM.name(f)
        @assert isempty(uses(f))
        replace_metadata_uses!(f, new_f)
        erase!(f)
        LLVM.name!(new_f, old_name)

        # minimal optimization
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, SimplifyCFGPass())
            run!(pb, new_f, llvm_machine(job.config.target))
        end

        return new_f
    end
end
0 commit comments