Skip to content

Commit e3894ce

Browse files
committed
add_input_arguments! for other backends
Allows other backends to pass additional hidden arguments that can be accessed through intrinsics. Required for OpenCL device-side RNG support, where additional shared memory must be passed as arguments to the kernel. Replaces #717
1 parent e9ad136 commit e3894ce

File tree

2 files changed

+174
-174
lines changed

2 files changed

+174
-174
lines changed

src/irgen.jl

Lines changed: 173 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -788,44 +788,47 @@ function cleanup_kernel_state!(mod::LLVM.Module)
788788
end
789789
CleanupKernelStatePass() = NewPMModulePass("CleanupKernelStatePass", cleanup_kernel_state!)
790790

791-
function kernel_state_intr(mod::LLVM.Module, T_state)
792-
state_intr = if haskey(functions(mod), "julia.gpu.state_getter")
793-
functions(mod)["julia.gpu.state_getter"]
791+
function custom_intr(mod::LLVM.Module, T::LLVMType, name::String)
792+
custom_intr = if haskey(functions(mod), name)
793+
functions(mod)[name]
794794
else
795-
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_state))
795+
LLVM.Function(mod, name, LLVM.FunctionType(T))
796796
end
797-
push!(function_attributes(state_intr), EnumAttribute("readnone", 0))
797+
push!(function_attributes(custom_intr), EnumAttribute("readnone", 0))
798798

799-
return state_intr
799+
return custom_intr
800800
end
801801

802802
# run-time equivalent
803-
function kernel_state_value(state)
803+
function call_custom_intrinsic(T::Type, name::String, call_name::String)
804804
@dispose ctx=Context() begin
805-
T_state = convert(LLVMType, state)
805+
T_llvm = convert(LLVMType, T)
806806

807807
# create function
808-
llvm_f, _ = create_function(T_state)
808+
llvm_f, _ = create_function(T_llvm)
809809
mod = LLVM.parent(llvm_f)
810810

811811
# get intrinsic
812-
state_intr = kernel_state_intr(mod, T_state)
813-
state_intr_ft = function_type(state_intr)
812+
_custom_intr = custom_intr(mod, T_llvm, name)
813+
custom_intr_ft = function_type(_custom_intr)
814814

815815
# generate IR
816816
@dispose builder=IRBuilder() begin
817817
entry = BasicBlock(llvm_f, "entry")
818818
position!(builder, entry)
819819

820-
val = call!(builder, state_intr_ft, state_intr, Value[], "state")
820+
val = call!(builder, custom_intr_ft, _custom_intr, Value[], call_name)
821821

822822
ret!(builder, val)
823823
end
824824

825-
call_function(llvm_f, state)
825+
call_function(llvm_f, T)
826826
end
827827
end
828828

829+
kernel_state_intr(mod::LLVM.Module, T::LLVMType) = custom_intr(mod, T, "julia.gpu.state_getter")
830+
kernel_state_value(T::Type) = call_custom_intrinsic(T, "julia.gpu.state_getter", "state")
831+
829832
# convert kernel state argument from pass-by-value to pass-by-reference
830833
#
831834
# the kernel state argument is always passed by value to avoid codegen issues with byval.
@@ -921,3 +924,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
921924
return new_f
922925
end
923926
end
927+
928+
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
929+
entry::LLVM.Function, kernel_intrinsics::Dict)
930+
entry_fn = LLVM.name(entry)
931+
932+
# figure out which intrinsics are used and need to be added as arguments
933+
used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
934+
haskey(functions(mod), intr_fn)
935+
end |> collect
936+
nargs = length(used_intrinsics)
937+
938+
# determine which functions need these arguments
939+
worklist = Set{LLVM.Function}([entry])
940+
for intr_fn in used_intrinsics
941+
push!(worklist, functions(mod)[intr_fn])
942+
end
943+
worklist_length = 0
944+
while worklist_length != length(worklist)
945+
# iteratively discover functions that use an intrinsic or any function calling it
946+
worklist_length = length(worklist)
947+
additions = LLVM.Function[]
948+
for f in worklist, use in uses(f)
949+
inst = user(use)::Instruction
950+
bb = LLVM.parent(inst)
951+
new_f = LLVM.parent(bb)
952+
in(new_f, worklist) || push!(additions, new_f)
953+
end
954+
for f in additions
955+
push!(worklist, f)
956+
end
957+
end
958+
for intr_fn in used_intrinsics
959+
delete!(worklist, functions(mod)[intr_fn])
960+
end
961+
962+
# add the arguments
963+
# NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
964+
workmap = Dict{LLVM.Function, LLVM.Function}()
965+
for f in worklist
966+
fn = LLVM.name(f)
967+
ft = function_type(f)
968+
LLVM.name!(f, fn * ".orig")
969+
# create a new function
970+
new_param_types = LLVMType[parameters(ft)...]
971+
972+
for intr_fn in used_intrinsics
973+
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
974+
push!(new_param_types, llvm_typ)
975+
end
976+
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
977+
new_f = LLVM.Function(mod, fn, new_ft)
978+
linkage!(new_f, linkage(f))
979+
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
980+
LLVM.name!(new_arg, LLVM.name(arg))
981+
end
982+
for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
983+
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
984+
end
985+
986+
workmap[f] = new_f
987+
end
988+
989+
# clone and rewrite the function bodies.
990+
# we don't need to rewrite much as the arguments are added last.
991+
for (f, new_f) in workmap
992+
# map the arguments
993+
value_map = Dict{LLVM.Value, LLVM.Value}()
994+
for (param, new_param) in zip(parameters(f), parameters(new_f))
995+
LLVM.name!(new_param, LLVM.name(param))
996+
value_map[param] = new_param
997+
end
998+
999+
value_map[f] = new_f
1000+
clone_into!(new_f, f; value_map,
1001+
changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)
1002+
1003+
# we can't remove this function yet, as we might still need to rewrite calls to it,
1004+
# but remove the IR already
1005+
empty!(f)
1006+
end
1007+
1008+
# drop unused constants that may be referring to the old functions
1009+
# XXX: can we do this differently?
1010+
for f in worklist
1011+
prune_constexpr_uses!(f)
1012+
end
1013+
1014+
# update other uses of the old function, modifying call sites to pass the arguments
1015+
function rewrite_uses!(f, new_f)
1016+
# update uses
1017+
@dispose builder=IRBuilder() begin
1018+
for use in uses(f)
1019+
val = user(use)
1020+
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
1021+
callee_f = LLVM.parent(LLVM.parent(val))
1022+
# forward the arguments
1023+
position!(builder, val)
1024+
new_val = if val isa LLVM.CallInst
1025+
call!(builder, function_type(new_f), new_f,
1026+
[arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
1027+
operand_bundles(val))
1028+
else
1029+
# TODO: invoke and callbr
1030+
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
1031+
end
1032+
callconv!(new_val, callconv(val))
1033+
1034+
replace_uses!(val, new_val)
1035+
@assert isempty(uses(val))
1036+
erase!(val)
1037+
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
1038+
# XXX: why isn't this caught by the value materializer above?
1039+
target = operands(val)[1]
1040+
@assert target == f
1041+
new_val = LLVM.const_bitcast(new_f, value_type(val))
1042+
rewrite_uses!(val, new_val)
1043+
# we can't simply replace this constant expression, as it may be used
1044+
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)
1045+
1046+
# drop the old constant if it is unused
1047+
# XXX: can we do this differently?
1048+
if isempty(uses(val))
1049+
LLVM.unsafe_destroy!(val)
1050+
end
1051+
else
1052+
error("Cannot rewrite unknown use of function: $val")
1053+
end
1054+
end
1055+
end
1056+
end
1057+
for (f, new_f) in workmap
1058+
rewrite_uses!(f, new_f)
1059+
@assert isempty(uses(f))
1060+
erase!(f)
1061+
end
1062+
1063+
# replace uses of the intrinsics with references to the input arguments
1064+
for (i, intr_fn) in enumerate(used_intrinsics)
1065+
intr = functions(mod)[intr_fn]
1066+
for use in uses(intr)
1067+
val = user(use)
1068+
callee_f = LLVM.parent(LLVM.parent(val))
1069+
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
1070+
replace_uses!(val, parameters(callee_f)[end-nargs+i])
1071+
else
1072+
error("Cannot rewrite unknown use of function: $val")
1073+
end
1074+
1075+
@assert isempty(uses(val))
1076+
erase!(val)
1077+
end
1078+
@assert isempty(uses(intr))
1079+
erase!(intr)
1080+
end
1081+
1082+
return functions(mod)[entry_fn]
1083+
end

src/metal.jl

Lines changed: 1 addition & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
5353
# update calling conventions
5454
if job.config.kernel
5555
entry = pass_by_reference!(job, mod, entry)
56-
57-
add_input_arguments!(job, mod, entry)
58-
entry = LLVM.functions(mod)[entry_fn]
56+
entry = add_input_arguments!(job, mod, entry, kernel_intrinsics)
5957
end
6058

6159
# emit the AIR and Metal version numbers as constants in the module. this makes it
@@ -553,164 +551,6 @@ function argument_type_name(typ)
553551
end
554552
end
555553

556-
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
557-
entry::LLVM.Function)
558-
entry_fn = LLVM.name(entry)
559-
560-
# figure out which intrinsics are used and need to be added as arguments
561-
used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
562-
haskey(functions(mod), intr_fn)
563-
end |> collect
564-
nargs = length(used_intrinsics)
565-
566-
# determine which functions need these arguments
567-
worklist = Set{LLVM.Function}([entry])
568-
for intr_fn in used_intrinsics
569-
push!(worklist, functions(mod)[intr_fn])
570-
end
571-
worklist_length = 0
572-
while worklist_length != length(worklist)
573-
# iteratively discover functions that use an intrinsic or any function calling it
574-
worklist_length = length(worklist)
575-
additions = LLVM.Function[]
576-
for f in worklist, use in uses(f)
577-
inst = user(use)::Instruction
578-
bb = LLVM.parent(inst)
579-
new_f = LLVM.parent(bb)
580-
in(new_f, worklist) || push!(additions, new_f)
581-
end
582-
for f in additions
583-
push!(worklist, f)
584-
end
585-
end
586-
for intr_fn in used_intrinsics
587-
delete!(worklist, functions(mod)[intr_fn])
588-
end
589-
590-
# add the arguments
591-
# NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
592-
workmap = Dict{LLVM.Function, LLVM.Function}()
593-
for f in worklist
594-
fn = LLVM.name(f)
595-
ft = function_type(f)
596-
LLVM.name!(f, fn * ".orig")
597-
# create a new function
598-
new_param_types = LLVMType[parameters(ft)...]
599-
600-
for intr_fn in used_intrinsics
601-
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
602-
push!(new_param_types, llvm_typ)
603-
end
604-
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
605-
new_f = LLVM.Function(mod, fn, new_ft)
606-
linkage!(new_f, linkage(f))
607-
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
608-
LLVM.name!(new_arg, LLVM.name(arg))
609-
end
610-
for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
611-
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
612-
end
613-
614-
workmap[f] = new_f
615-
end
616-
617-
# clone and rewrite the function bodies.
618-
# we don't need to rewrite much as the arguments are added last.
619-
for (f, new_f) in workmap
620-
# map the arguments
621-
value_map = Dict{LLVM.Value, LLVM.Value}()
622-
for (param, new_param) in zip(parameters(f), parameters(new_f))
623-
LLVM.name!(new_param, LLVM.name(param))
624-
value_map[param] = new_param
625-
end
626-
627-
value_map[f] = new_f
628-
clone_into!(new_f, f; value_map,
629-
changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)
630-
631-
# we can't remove this function yet, as we might still need to rewrite any called,
632-
# but remove the IR already
633-
empty!(f)
634-
end
635-
636-
# drop unused constants that may be referring to the old functions
637-
# XXX: can we do this differently?
638-
for f in worklist
639-
prune_constexpr_uses!(f)
640-
end
641-
642-
# update other uses of the old function, modifying call sites to pass the arguments
643-
function rewrite_uses!(f, new_f)
644-
# update uses
645-
@dispose builder=IRBuilder() begin
646-
for use in uses(f)
647-
val = user(use)
648-
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
649-
callee_f = LLVM.parent(LLVM.parent(val))
650-
# forward the arguments
651-
position!(builder, val)
652-
new_val = if val isa LLVM.CallInst
653-
call!(builder, function_type(new_f), new_f,
654-
[arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
655-
operand_bundles(val))
656-
else
657-
# TODO: invoke and callbr
658-
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
659-
end
660-
callconv!(new_val, callconv(val))
661-
662-
replace_uses!(val, new_val)
663-
@assert isempty(uses(val))
664-
erase!(val)
665-
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
666-
# XXX: why isn't this caught by the value materializer above?
667-
target = operands(val)[1]
668-
@assert target == f
669-
new_val = LLVM.const_bitcast(new_f, value_type(val))
670-
rewrite_uses!(val, new_val)
671-
# we can't simply replace this constant expression, as it may be used
672-
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)
673-
674-
# drop the old constant if it is unused
675-
# XXX: can we do this differently?
676-
if isempty(uses(val))
677-
LLVM.unsafe_destroy!(val)
678-
end
679-
else
680-
error("Cannot rewrite unknown use of function: $val")
681-
end
682-
end
683-
end
684-
end
685-
for (f, new_f) in workmap
686-
rewrite_uses!(f, new_f)
687-
@assert isempty(uses(f))
688-
erase!(f)
689-
end
690-
691-
# replace uses of the intrinsics with references to the input arguments
692-
for (i, intr_fn) in enumerate(used_intrinsics)
693-
intr = functions(mod)[intr_fn]
694-
for use in uses(intr)
695-
val = user(use)
696-
callee_f = LLVM.parent(LLVM.parent(val))
697-
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
698-
replace_uses!(val, parameters(callee_f)[end-nargs+i])
699-
else
700-
error("Cannot rewrite unknown use of function: $val")
701-
end
702-
703-
@assert isempty(uses(val))
704-
erase!(val)
705-
end
706-
@assert isempty(uses(intr))
707-
erase!(intr)
708-
end
709-
710-
return
711-
end
712-
713-
714554
# argument metadata generation
715555
#
716556
# module metadata is used to identify buffers that are passed as kernel arguments.

0 commit comments

Comments
 (0)