@@ -790,44 +790,47 @@ function cleanup_kernel_state!(mod::LLVM.Module)
790790end
# Construct a `NewPMModulePass` named "CleanupKernelStatePass" that wraps
# `cleanup_kernel_state!`.
function CleanupKernelStatePass()
    return NewPMModulePass("CleanupKernelStatePass", cleanup_kernel_state!)
end
792792
"""
    custom_intr(mod::LLVM.Module, T::LLVMType, name::String)

Return the function called `name` in `mod`. If the module does not contain such a
function yet, declare it first with the function type `LLVM.FunctionType(T)`, i.e. no
parameters and `T` as the return type. In both cases the `readnone` enum attribute is
added to the function before it is returned.
"""
function custom_intr(mod::LLVM.Module, T::LLVMType, name::String)
    # reuse an existing declaration when there is one, otherwise create it
    intr = haskey(functions(mod), name) ? functions(mod)[name] :
           LLVM.Function(mod, name, LLVM.FunctionType(T))

    # the attribute is (re-)added even when the declaration already existed
    push!(function_attributes(intr), EnumAttribute("readnone", 0))

    return intr
end
803803
804804# run-time equivalent
"""
    call_custom_intrinsic(T::Type, name::String, call_name::String)

Generate a function whose body consists of a single call to the intrinsic `name`
followed by a return of the call's result, and pass that function to `call_function`
together with `T`.

The intrinsic is looked up (or declared) by `custom_intr`, using the LLVM type obtained
from `convert(LLVMType, T)`. `call_name` is the name given to the call instruction's
result in the generated IR.
"""
function call_custom_intrinsic(T::Type, name::String, call_name::String)
    @dispose ctx=Context() begin
        ret_typ = convert(LLVMType, T)

        # the function that will wrap the intrinsic call, and the module it lives in
        wrapper, _ = create_function(ret_typ)
        wrapper_mod = LLVM.parent(wrapper)

        # look up (or declare) the intrinsic in that module
        intr = custom_intr(wrapper_mod, ret_typ, name)
        intr_ft = function_type(intr)

        # emit a single `entry` block: call the intrinsic without arguments and
        # return its result
        @dispose builder=IRBuilder() begin
            entry_bb = BasicBlock(wrapper, "entry")
            position!(builder, entry_bb)

            result = call!(builder, intr_ft, intr, Value[], call_name)

            ret!(builder, result)
        end

        call_function(wrapper, T)
    end
end
830830
# Look up (or declare) the "julia.gpu.state_getter" intrinsic in `mod`, with `T` as its
# return type.
function kernel_state_intr(mod::LLVM.Module, T::LLVMType)
    return custom_intr(mod, T, "julia.gpu.state_getter")
end

# Emit a call to the "julia.gpu.state_getter" intrinsic returning a value of type `T`;
# the call's result is named "state" in the generated IR.
function kernel_state_value(T::Type)
    return call_custom_intrinsic(T, "julia.gpu.state_getter", "state")
end
833+
831834# convert kernel state argument from pass-by-value to pass-by-reference
832835#
833836# the kernel state argument is always passed by value to avoid codegen issues with byval.
@@ -923,3 +926,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
923926 return new_f
924927 end
925928end
929+
"""
    add_input_arguments!(job::CompilerJob, mod::LLVM.Module, entry::LLVM.Function,
                         kernel_intrinsics::Dict)

Turn calls to kernel intrinsics into additional function arguments.

Every key of `kernel_intrinsics` that names a function present in `mod` is treated as a
used intrinsic. `entry`, every function that uses such an intrinsic, and (transitively)
every function that uses one of those functions is replaced by a new function of the
same name whose parameter list is extended with one trailing parameter per used
intrinsic. The parameter's LLVM type is converted from the `typ` field of the
corresponding dictionary value, and its name is taken from the `name` field.

Call sites of the replaced functions are rewritten to forward the caller's own extra
parameters, and calls to the intrinsics are replaced by the matching parameter. The
intrinsic declarations and the old functions are erased from the module.

`job` is not used by the body.

Returns the function registered in `mod` under the original name of `entry`.

Throws an error when a replaced function is called through an `invoke` or `callbr`
instruction, or when a function or intrinsic has a use that is neither a call nor (for
replaced functions) a bitcast constant expression.
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, kernel_intrinsics::Dict)
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics themselves were only seeds for the search; they are not rewritten
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters: existing ones keep their name, the trailing ones are
        # named after the intrinsic they replace
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (their names have already been copied above)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as its remaining uses still have to be
        # rewritten, but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        # update uses
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing this call site; its trailing parameters
                    # are the values to forward
                    caller_f = LLVM.parent(LLVM.parent(val))
                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        call!(builder, function_type(new_f), new_f,
                              [arguments(val)..., parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments)
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # only look up the enclosing function once we know `val` is an
                # instruction, so that other kinds of users reach the error below
                caller_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(caller_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    return functions(mod)[entry_fn]
end
0 commit comments