@@ -788,44 +788,47 @@ function cleanup_kernel_state!(mod::LLVM.Module)
788788end
789789CleanupKernelStatePass () = NewPMModulePass (" CleanupKernelStatePass" , cleanup_kernel_state!)
790790
791- function kernel_state_intr (mod:: LLVM.Module , T_state )
792- state_intr = if haskey (functions (mod), " julia.gpu.state_getter " )
793- functions (mod)[" julia.gpu.state_getter " ]
791+ function custom_intr (mod:: LLVM.Module , T :: LLVMType , name :: String )
792+ custom_intr = if haskey (functions (mod), name )
793+ functions (mod)[name ]
794794 else
795- LLVM. Function (mod, " julia.gpu.state_getter " , LLVM. FunctionType (T_state ))
795+ LLVM. Function (mod, name , LLVM. FunctionType (T ))
796796 end
797- push! (function_attributes (state_intr ), EnumAttribute (" readnone" , 0 ))
797+ push! (function_attributes (custom_intr ), EnumAttribute (" readnone" , 0 ))
798798
799- return state_intr
799+ return custom_intr
800800end
801801
802802# run-time equivalent
803- function kernel_state_value (state )
803+ function call_custom_intrinsic (T :: Type , name :: String , call_name :: String )
804804 @dispose ctx= Context () begin
805- T_state = convert (LLVMType, state )
805+ T_llvm = convert (LLVMType, T )
806806
807807 # create function
808- llvm_f, _ = create_function (T_state )
808+ llvm_f, _ = create_function (T_llvm )
809809 mod = LLVM. parent (llvm_f)
810810
811811 # get intrinsic
812- state_intr = kernel_state_intr (mod, T_state )
813- state_intr_ft = function_type (state_intr )
812+ _custom_intr = custom_intr (mod, T_llvm, name )
813+ custom_intr_ft = function_type (_custom_intr )
814814
815815 # generate IR
816816 @dispose builder= IRBuilder () begin
817817 entry = BasicBlock (llvm_f, " entry" )
818818 position! (builder, entry)
819819
820- val = call! (builder, state_intr_ft, state_intr , Value[], " state " )
820+ val = call! (builder, custom_intr_ft, _custom_intr , Value[], call_name )
821821
822822 ret! (builder, val)
823823 end
824824
825- call_function (llvm_f, state )
825+ call_function (llvm_f, T )
826826 end
827827end
828828
829+ kernel_state_intr (mod:: LLVM.Module , T:: LLVMType ) = custom_intr (mod, T, " julia.gpu.state_getter" )
830+ kernel_state_value (T:: Type ) = call_custom_intrinsic (T, " julia.gpu.state_getter" , " state" )
831+
829832# convert kernel state argument from pass-by-value to pass-by-reference
830833#
831834# the kernel state argument is always passed by value to avoid codegen issues with byval.
@@ -921,3 +924,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
921924 return new_f
922925 end
923926end
927+
928+ function add_input_arguments! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
929+ entry:: LLVM.Function , kernel_intrinsics:: Dict )
930+ entry_fn = LLVM. name (entry)
931+
932+ # figure out which intrinsics are used and need to be added as arguments
933+ used_intrinsics = filter (keys (kernel_intrinsics)) do intr_fn
934+ haskey (functions (mod), intr_fn)
935+ end |> collect
936+ nargs = length (used_intrinsics)
937+
938+ # determine which functions need these arguments
939+ worklist = Set {LLVM.Function} ([entry])
940+ for intr_fn in used_intrinsics
941+ push! (worklist, functions (mod)[intr_fn])
942+ end
943+ worklist_length = 0
944+ while worklist_length != length (worklist)
945+ # iteratively discover functions that use an intrinsic or any function calling it
946+ worklist_length = length (worklist)
947+ additions = LLVM. Function[]
948+ for f in worklist, use in uses (f)
949+ inst = user (use):: Instruction
950+ bb = LLVM. parent (inst)
951+ new_f = LLVM. parent (bb)
952+ in (new_f, worklist) || push! (additions, new_f)
953+ end
954+ for f in additions
955+ push! (worklist, f)
956+ end
957+ end
958+ for intr_fn in used_intrinsics
959+ delete! (worklist, functions (mod)[intr_fn])
960+ end
961+
962+ # add the arguments
963+ # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
964+ workmap = Dict {LLVM.Function, LLVM.Function} ()
965+ for f in worklist
966+ fn = LLVM. name (f)
967+ ft = function_type (f)
968+ LLVM. name! (f, fn * " .orig" )
969+ # create a new function
970+ new_param_types = LLVMType[parameters (ft)... ]
971+
972+ for intr_fn in used_intrinsics
973+ llvm_typ = convert (LLVMType, kernel_intrinsics[intr_fn]. typ)
974+ push! (new_param_types, llvm_typ)
975+ end
976+ new_ft = LLVM. FunctionType (return_type (ft), new_param_types)
977+ new_f = LLVM. Function (mod, fn, new_ft)
978+ linkage! (new_f, linkage (f))
979+ for (arg, new_arg) in zip (parameters (f), parameters (new_f))
980+ LLVM. name! (new_arg, LLVM. name (arg))
981+ end
982+ for (intr_fn, new_arg) in zip (used_intrinsics, parameters (new_f)[end - nargs+ 1 : end ])
983+ LLVM. name! (new_arg, kernel_intrinsics[intr_fn]. name)
984+ end
985+
986+ workmap[f] = new_f
987+ end
988+
989+ # clone and rewrite the function bodies.
990+ # we don't need to rewrite much as the arguments are added last.
991+ for (f, new_f) in workmap
992+ # map the arguments
993+ value_map = Dict {LLVM.Value, LLVM.Value} ()
994+ for (param, new_param) in zip (parameters (f), parameters (new_f))
995+ LLVM. name! (new_param, LLVM. name (param))
996+ value_map[param] = new_param
997+ end
998+
999+ value_map[f] = new_f
1000+ clone_into! (new_f, f; value_map,
1001+ changes= LLVM. API. LLVMCloneFunctionChangeTypeLocalChangesOnly)
1002+
1003+ # we can't remove this function yet, as we might still need to rewrite any called,
1004+ # but remove the IR already
1005+ empty! (f)
1006+ end
1007+
1008+ # drop unused constants that may be referring to the old functions
1009+ # XXX : can we do this differently?
1010+ for f in worklist
1011+ prune_constexpr_uses! (f)
1012+ end
1013+
1014+ # update other uses of the old function, modifying call sites to pass the arguments
1015+ function rewrite_uses! (f, new_f)
1016+ # update uses
1017+ @dispose builder= IRBuilder () begin
1018+ for use in uses (f)
1019+ val = user (use)
1020+ if val isa LLVM. CallInst || val isa LLVM. InvokeInst || val isa LLVM. CallBrInst
1021+ callee_f = LLVM. parent (LLVM. parent (val))
1022+ # forward the arguments
1023+ position! (builder, val)
1024+ new_val = if val isa LLVM. CallInst
1025+ call! (builder, function_type (new_f), new_f,
1026+ [arguments (val)... , parameters (callee_f)[end - nargs+ 1 : end ]. .. ],
1027+ operand_bundles (val))
1028+ else
1029+ # TODO : invoke and callbr
1030+ error (" Rewrite of $(typeof (val)) -based calls is not implemented: $val " )
1031+ end
1032+ callconv! (new_val, callconv (val))
1033+
1034+ replace_uses! (val, new_val)
1035+ @assert isempty (uses (val))
1036+ erase! (val)
1037+ elseif val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMBitCast
1038+ # XXX : why isn't this caught by the value materializer above?
1039+ target = operands (val)[1 ]
1040+ @assert target == f
1041+ new_val = LLVM. const_bitcast (new_f, value_type (val))
1042+ rewrite_uses! (val, new_val)
1043+ # we can't simply replace this constant expression, as it may be used
1044+ # as a call, taking arguments (so we need to rewrite it to pass the input arguments)
1045+
1046+ # drop the old constant if it is unused
1047+ # XXX : can we do this differently?
1048+ if isempty (uses (val))
1049+ LLVM. unsafe_destroy! (val)
1050+ end
1051+ else
1052+ error (" Cannot rewrite unknown use of function: $val " )
1053+ end
1054+ end
1055+ end
1056+ end
1057+ for (f, new_f) in workmap
1058+ rewrite_uses! (f, new_f)
1059+ @assert isempty (uses (f))
1060+ erase! (f)
1061+ end
1062+
1063+ # replace uses of the intrinsics with references to the input arguments
1064+ for (i, intr_fn) in enumerate (used_intrinsics)
1065+ intr = functions (mod)[intr_fn]
1066+ for use in uses (intr)
1067+ val = user (use)
1068+ callee_f = LLVM. parent (LLVM. parent (val))
1069+ if val isa LLVM. CallInst || val isa LLVM. InvokeInst || val isa LLVM. CallBrInst
1070+ replace_uses! (val, parameters (callee_f)[end - nargs+ i])
1071+ else
1072+ error (" Cannot rewrite unknown use of function: $val " )
1073+ end
1074+
1075+ @assert isempty (uses (val))
1076+ erase! (val)
1077+ end
1078+ @assert isempty (uses (intr))
1079+ erase! (intr)
1080+ end
1081+
1082+ return functions (mod)[entry_fn]
1083+ end
0 commit comments