@@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob))
8282
8383 # minimal required optimization
8484 @tracepoint " rewrite" begin
85- if job. config. kernel && needs_byval (job)
85+ if job. config. kernel && pass_by_value (job)
8686 # pass all bitstypes by value; by default Julia passes aggregates by reference
8787 # (this improves performance, and is mandated by certain back-ends like SPIR-V).
8888 args = classify_arguments (job, function_type (entry))
@@ -256,10 +256,11 @@ end
256256# # kernel promotion
257257
# Calling convention of a single kernel argument, as assigned by `classify_arguments`.
@enum ArgumentCC begin
    BITS_VALUE    # bitstype, passed as value
    BITS_REF      # bitstype, passed as pointer
    MUT_REF       # jl_value_t*, or the anonymous equivalent
    GHOST         # not passed
    KERNEL_STATE  # the kernel state argument
end
264265
265266# Determine the calling convention of the arguments of a Julia function, given the
270271# - `name`: the name of the argument
271272# - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
272273# is not passed at the LLVM level.
273- function classify_arguments (@nospecialize (job:: CompilerJob ), codegen_ft:: LLVM.FunctionType )
274+ function classify_arguments (@nospecialize (job:: CompilerJob ), codegen_ft:: LLVM.FunctionType ;
275+ post_optimization:: Bool = false )
274276 source_sig = job. source. specTypes
275277 source_types = [source_sig. parameters... ]
276278
@@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
282284
283285 codegen_types = parameters (codegen_ft)
284286
285- args = []
286- codegen_i = 1
287- for (source_i, (source_typ, source_name)) in enumerate (zip (source_types, source_argnames))
287+ if post_optimization && kernel_state_type (job) != = Nothing
288+ args = []
289+ push! (args, (cc= KERNEL_STATE, typ= kernel_state_type (job), name= :kernel_state , idx= 1 ))
290+ codegen_i = 2
291+ else
292+ args = []
293+ codegen_i = 1
294+ end
295+ for (source_typ, source_name) in zip (source_types, source_argnames)
288296 if isghosttype (source_typ) || Core. Compiler. isconstType (source_typ)
289297 push! (args, (cc= GHOST, typ= source_typ, name= source_name, idx= nothing ))
290298 continue
@@ -817,3 +825,256 @@ function kernel_state_value(state)
817825 call_function (llvm_f, state)
818826 end
819827end
828+
# Rewrite a kernel so that its kernel state is received through a pointer.
#
# The kernel state argument is always passed by value, which avoids codegen issues with
# `byval`. Back-ends that do not support by-value kernel arguments use this pass to turn
# that argument into a pointer that is loaded on entry; conceptually this is the inverse
# of `lower_byval`.
#
# Returns `f` untouched when the job has no kernel state, or when the first parameter of
# `f` does not have the kernel state type. Otherwise `f` is erased from `mod` and the
# replacement function (which takes over the name of `f`) is returned.
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    f::LLVM.Function)
    # jobs without kernel state have nothing to convert
    state = kernel_state_type(job)
    state === Nothing && return f
    T_state = convert(LLVMType, state)

    # the kernel state is expected to be the leading parameter
    ft = function_type(f)
    old_types = parameters(ft)
    has_state = !isempty(old_types) && value_type(parameters(f)[1]) == T_state
    has_state || return f

    @tracepoint "kernel state to reference" begin
        # same signature, except that the leading parameter becomes a pointer
        new_types = LLVM.LLVMType[LLVM.PointerType(T_state), old_types[2:end]...]
        new_ft = LLVM.FunctionType(return_type(ft), new_types)
        new_f = LLVM.Function(mod, "", new_ft)
        linkage!(new_f, linkage(f))

        # carry over the parameter names
        new_params = parameters(new_f)
        LLVM.name!(new_params[1], "state_ptr")
        for (old_param, new_param) in zip(parameters(f)[2:end], new_params[2:end])
            LLVM.name!(new_param, LLVM.name(old_param))
        end

        @dispose builder=IRBuilder() begin
            # prologue block that materializes the values the original body expects
            entry = BasicBlock(new_f, "conversion")
            position!(builder, entry)

            # the state is loaded through the pointer, everything else is forwarded as-is
            state_val = load!(builder, T_state, new_params[1], "state")
            new_args = LLVM.Value[state_val, new_params[2:end]...]

            # substitute the old parameters (and references to `f` itself) while cloning
            value_map = Dict{LLVM.Value, LLVM.Value}(f => new_f)
            for (old_param, replacement) in zip(parameters(f), new_args)
                value_map[old_param] = replacement
            end
            clone_into!(new_f, f; value_map,
                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

            # fall through from the prologue into the first cloned block
            br!(builder, blocks(new_f)[2])
        end

        # annotate the state pointer: it is immediately loaded from (so it cannot be
        # captured), each kernel state is separate, and the state is read-only
        state_attrs = parameter_attributes(new_f, 1)
        for attr_name in ("nocapture", "noalias", "readonly")
            push!(state_attrs, EnumAttribute(attr_name, 0))
        end

        # drop the old function and let the new one take over its name
        old_name = LLVM.name(f)
        @assert isempty(uses(f))
        replace_metadata_uses!(f, new_f)
        erase!(f)
        LLVM.name!(new_f, old_name)

        # minimal optimization, cleaning up the extra prologue block
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, SimplifyCFGPass())
            run!(pb, new_f, llvm_machine(job.config.target))
        end

        return new_f
    end
end
924+
# Turn uses of kernel intrinsics into additional, trailing function arguments.
#
# `kernel_intrinsics` maps the name of an intrinsic function to a descriptor with a `typ`
# field (the Julia type of the value, converted to an LLVM type here) and a `name` field
# (used to name the added argument). Every intrinsic that exists in `mod` is appended as
# an extra parameter to `entry` and to every function that, directly or through its
# callees, uses such an intrinsic or the entry point. Call sites are rewritten to forward
# these parameters, calls to the intrinsics are replaced by the matching parameter, and
# the intrinsic functions are erased from the module.
#
# Returns the function in `mod` that carries the name of `entry` after the rewrite (which
# is `entry` itself when none of the intrinsics occur in the module).
#
# Only plain `call` instructions can be rewritten: `invoke`/`callbr` users of a rewritten
# function, and any other unknown user, raise an error.
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, kernel_intrinsics::Dict)
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # without any intrinsic in the module there is nothing to add or rewrite,
    # so avoid needlessly cloning every function
    nargs == 0 && return entry

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics themselves only seeded the search; they are not rewritten
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function, with one extra trailing parameter per used intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters: existing ones keep their name, added ones get the
        # name provided by the intrinsic descriptor
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (the parameters have already been named above)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as call sites still refer to it and need to
        # be rewritten below, but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing the call; its trailing parameters are the
                    # input arguments that need to be forwarded
                    callee_f = LLVM.parent(LLVM.parent(val))
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        call!(builder, function_type(new_f), new_f,
                              [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments); recurse into the users of the bitcast instead
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # look up the enclosing function only once we know `val` is an
                # instruction, so that unexpected users reach the descriptive error below
                callee_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(callee_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    # the entry point has been replaced; look up the function now carrying its name
    return functions(mod)[entry_fn]
end
0 commit comments