Skip to content

Commit 07a039a

Browse files
committed
add_input_arguments! for other backends
Allows other backends to pass additional hidden arguments that can be accessed through intrinsics. Required for OpenCL device-side RNG support, where additional shared memory must be passed as arguments to the kernel. Replaces #717
1 parent dce63fc commit 07a039a

File tree

2 files changed

+174
-174
lines changed

2 files changed

+174
-174
lines changed

src/irgen.jl

Lines changed: 173 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -790,44 +790,47 @@ function cleanup_kernel_state!(mod::LLVM.Module)
790790
end
791791
# Build the module pass named "CleanupKernelStatePass", which runs `cleanup_kernel_state!`.
function CleanupKernelStatePass()
    return NewPMModulePass("CleanupKernelStatePass", cleanup_kernel_state!)
end
792792

793-
# Diff hunk: the removed `kernel_state_intr(mod, T_state)` (gutter lines ending in `-`) is
# generalized into `custom_intr(mod, T, name)` (gutter lines ending in `+`); the hard-coded
# intrinsic name "julia.gpu.state_getter" becomes the `name` parameter.
#
# `custom_intr` returns the function called `name` from `mod` when one already exists, and
# otherwise declares a new zero-parameter function of that name whose return type is `T`.
# In both cases a `readnone` enum attribute is pushed onto the function before it is returned.
function kernel_state_intr(mod::LLVM.Module, T_state)
794-
state_intr = if haskey(functions(mod), "julia.gpu.state_getter")
795-
functions(mod)["julia.gpu.state_getter"]
793+
function custom_intr(mod::LLVM.Module, T::LLVMType, name::String)
794+
# NOTE: this local shares its name with the enclosing function and shadows it in the body.
custom_intr = if haskey(functions(mod), name)
795+
functions(mod)[name]
796796
else
797-
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_state))
797+
# no parameter types are passed, so the declared intrinsic takes no arguments
LLVM.Function(mod, name, LLVM.FunctionType(T))
798798
end
799-
push!(function_attributes(state_intr), EnumAttribute("readnone", 0))
799+
# the attribute is pushed on every call, also when the function already existed
push!(function_attributes(custom_intr), EnumAttribute("readnone", 0))
800800

801-
return state_intr
801+
return custom_intr
802802
end
803803

804804
# run-time equivalent
805-
# Diff hunk: the removed `kernel_state_value(state)` (gutter lines ending in `-`) is
# generalized into `call_custom_intrinsic(T, name, call_name)` (gutter lines ending in `+`).
#
# `call_custom_intrinsic` builds, in a temporary LLVM context, a function whose only
# instructions are a call to the intrinsic `name` and a return of that call's result, and
# then hands the function to `call_function` together with the Julia type `T`.
#   T          Julia type; converted to an LLVM type and used as the intrinsic's return type
#   name       name of the intrinsic function to look up or declare (see `custom_intr`)
#   call_name  name given to the value produced by the generated call instruction
# The result is whatever `call_function` produces.
# NOTE(review): presumably this is meant to be used from a generated-function context - confirm.
function kernel_state_value(state)
805+
function call_custom_intrinsic(T::Type, name::String, call_name::String)
806806
@dispose ctx=Context() begin
807-
T_state = convert(LLVMType, state)
807+
T_llvm = convert(LLVMType, T)
808808

809809
# create function
# (only the first element of the tuple returned by `create_function` is used)
810-
llvm_f, _ = create_function(T_state)
810+
llvm_f, _ = create_function(T_llvm)
811811
mod = LLVM.parent(llvm_f)
812812

813813
# get intrinsic
# (`_custom_intr` carries a leading underscore so it does not shadow `custom_intr`, which
# is called on the same line)
814-
state_intr = kernel_state_intr(mod, T_state)
815-
state_intr_ft = function_type(state_intr)
814+
_custom_intr = custom_intr(mod, T_llvm, name)
815+
custom_intr_ft = function_type(_custom_intr)
816816

817817
# generate IR
# a single "entry" block: call the intrinsic with no arguments and return its result
818818
@dispose builder=IRBuilder() begin
819819
entry = BasicBlock(llvm_f, "entry")
820820
position!(builder, entry)
821821

822-
val = call!(builder, state_intr_ft, state_intr, Value[], "state")
822+
val = call!(builder, custom_intr_ft, _custom_intr, Value[], call_name)
823823

824824
ret!(builder, val)
825825
end
826826

827-
call_function(llvm_f, state)
827+
# last expression of the `@dispose ctx` block, and therefore the function's result
call_function(llvm_f, T)
828828
end
829829
end
830830

831+
# Kernel-state flavour of `custom_intr`: always targets the "julia.gpu.state_getter" intrinsic.
function kernel_state_intr(mod::LLVM.Module, T::LLVMType)
    return custom_intr(mod, T, "julia.gpu.state_getter")
end
832+
# Kernel-state flavour of `call_custom_intrinsic`: calls the "julia.gpu.state_getter"
# intrinsic and names the resulting value "state".
function kernel_state_value(T::Type)
    return call_custom_intrinsic(T, "julia.gpu.state_getter", "state")
end
833+
831834
# convert kernel state argument from pass-by-value to pass-by-reference
832835
#
833836
# the kernel state argument is always passed by value to avoid codegen issues with byval.
@@ -923,3 +926,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
923926
return new_f
924927
end
925928
end
929+
930+
"""
    add_input_arguments!(job, mod, entry, kernel_intrinsics) -> LLVM.Function

Replace calls to the intrinsics listed in `kernel_intrinsics` by extra trailing function
arguments.

Every intrinsic whose name (a key of `kernel_intrinsics`) is present in `mod` gets one
additional parameter on `entry` and on every function that, directly or transitively,
calls such an intrinsic or calls a function that was extended. Existing call sites are
rewritten to forward the caller's extra parameters, calls to the intrinsics are replaced
by the matching parameter, and the intrinsic declarations are erased from `mod`.

# Arguments
- `job`: the compiler job; not used by the body.
- `mod`: module that is rewritten in place.
- `entry`: entry function; always part of the set of rewritten functions.
- `kernel_intrinsics`: maps intrinsic function names to objects with a `typ` field
  (converted to the LLVM type of the new parameter) and a `name` field (the name given to
  the new parameter).

# Returns
The function in `mod` that carries the original name of `entry` after the rewrite.

# Throws
Calls `error` for `invoke`/`callbr` call sites of a rewritten function (not implemented)
and for any use of a rewritten function or intrinsic that is neither a call nor, for
rewritten functions, a bitcast constant expression.
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, kernel_intrinsics::Dict)
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it,
        # until a full sweep adds nothing new
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics only seeded the search; they are not rewritten themselves
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function whose type has one extra trailing parameter per intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters: the original ones first, then the added ones
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (they were already named when `new_f` was created)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as we might still need to rewrite its call
        # sites, but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing the call; its trailing parameters are the
                    # values that have to be forwarded
                    caller_f = LLVM.parent(LLVM.parent(val))

                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        # explicitly typed, so that an empty list is not an `Any[]`
                        call!(builder, function_type(new_f), new_f,
                              LLVM.Value[arguments(val)...,
                                         parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments)
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # only look up the enclosing function once we know `val` is an
                # instruction, so that any other kind of user reaches the error below
                caller_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(caller_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    # the original entry has been erased; look up its replacement by name
    return functions(mod)[entry_fn]
end

src/metal.jl

Lines changed: 1 addition & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
5353
# update calling conventions
5454
if job.config.kernel
5555
entry = pass_by_reference!(job, mod, entry)
56-
57-
add_input_arguments!(job, mod, entry)
58-
entry = LLVM.functions(mod)[entry_fn]
56+
entry = add_input_arguments!(job, mod, entry, kernel_intrinsics)
5957
end
6058

6159
# emit the AIR and Metal version numbers as constants in the module. this makes it
@@ -553,164 +551,6 @@ function argument_type_name(typ)
553551
end
554552
end
555553

556-
"""
    add_input_arguments!(job, mod, entry)

Replace calls to the intrinsics listed in the global `kernel_intrinsics` by extra trailing
function arguments.

Every intrinsic whose name (a key of `kernel_intrinsics`) is present in `mod` gets one
additional parameter on `entry` and on every function that, directly or transitively,
calls such an intrinsic or calls a function that was extended. Existing call sites are
rewritten to forward the caller's extra parameters, calls to the intrinsics are replaced
by the matching parameter, and the intrinsic declarations are erased from `mod`.

# Arguments
- `job`: the compiler job; not used by the body.
- `mod`: module that is rewritten in place.
- `entry`: entry function; always part of the set of rewritten functions. It is erased
  and replaced by a new function of the same name, which callers have to look up in
  `mod` themselves.

# Returns
`nothing`.

# Throws
Calls `error` for `invoke`/`callbr` call sites of a rewritten function (not implemented)
and for any use of a rewritten function or intrinsic that is neither a call nor, for
rewritten functions, a bitcast constant expression.
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function)
    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it,
        # until a full sweep adds nothing new
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics only seeded the search; they are not rewritten themselves
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function whose type has one extra trailing parameter per intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters: the original ones first, then the added ones
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (they were already named when `new_f` was created)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as we might still need to rewrite its call
        # sites, but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing the call; its trailing parameters are the
                    # values that have to be forwarded
                    caller_f = LLVM.parent(LLVM.parent(val))

                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        # explicitly typed, so that an empty list is not an `Any[]`
                        call!(builder, function_type(new_f), new_f,
                              LLVM.Value[arguments(val)...,
                                         parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments)
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # only look up the enclosing function once we know `val` is an
                # instruction, so that any other kind of user reaches the error below
                caller_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(caller_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    return
end
712-
713-
714554
# argument metadata generation
715555
#
716556
# module metadata is used to identify buffers that are passed as kernel arguments.

0 commit comments

Comments
 (0)