Skip to content

Commit aac7d0f

Browse files
committed
some cleanup
1 parent 765bf60 commit aac7d0f

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

src/irgen.jl

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ function add_kernel_state!(mod::LLVM.Module)
530530

531531
# additional arguments to pass to every function, but only if they are required
532532
additional_args = haskey(functions(mod), "julia.gpu.additional_arg_getter") ? additional_arg_types(job) : (;)
533-
T_additional_args = convert.(LLVMType, values(additional_args))
534-
names_additional_args = String.(keys(additional_args))
533+
T_additional_args = LLVMType[convert(LLVMType, T) for T in values(additional_args)]
534+
names_additional_args = String[String(name) for name in keys(additional_args)]
535535

536536
additional_arg_intrs = additional_arg_intr.(Ref(mod), T_additional_args)
537537
additional_arg_intr_fts = LLVM.FunctionType.(T_additional_args)
@@ -804,13 +804,12 @@ function lower_kernel_state!(fun::LLVM.Function)
804804

805805
i = Int(convert(Int, operands(inst)[1]::ConstantInt))
806806
if additional_args[i] === nothing
807-
state_arg = parameters(fun)[end - length(additional_arg_tys) + i]
808-
T_state = convert(LLVMType, additional_arg_tys[i])
809-
@assert value_type(state_arg) == T_state
810-
additional_args[i] = state_arg
807+
additional_args[i] = parameters(fun)[end - length(additional_arg_tys) + i]
808+
T_arg = convert(LLVMType, additional_arg_tys[i])
809+
@assert value_type(additional_args[i]) == T_arg
811810
end
812811

813-
replace_uses!(inst, state_arg)
812+
replace_uses!(inst, additional_args[i])
814813

815814
@assert isempty(uses(inst))
816815
erase!(inst)
@@ -841,7 +840,6 @@ function cleanup_kernel_state!(mod::LLVM.Module)
841840
if haskey(functions(mod), "julia.gpu.additional_arg_getter")
842841
intr = functions(mod)["julia.gpu.additional_arg_getter"]
843842
if isempty(uses(intr))
844-
# if we're not emitting a kernel, we can't resolve the intrinsic to an argument.
845843
erase!(intr)
846844
changed = true
847845
end
@@ -985,40 +983,40 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
985983
end
986984
end
987985

988-
function additional_arg_intr(mod::LLVM.Module, T_state)
989-
state_intr = if haskey(functions(mod), "julia.gpu.additional_arg_getter")
986+
function additional_arg_intr(mod::LLVM.Module, T_arg)
987+
additional_arg_intr = if haskey(functions(mod), "julia.gpu.additional_arg_getter")
990988
functions(mod)["julia.gpu.additional_arg_getter"]
991989
else
992-
LLVM.Function(mod, "julia.gpu.additional_arg_getter", LLVM.FunctionType(T_state))
990+
LLVM.Function(mod, "julia.gpu.additional_arg_getter", LLVM.FunctionType(T_arg))
993991
end
994-
push!(function_attributes(state_intr), EnumAttribute("readnone", 0))
992+
push!(function_attributes(additional_arg_intr), EnumAttribute("readnone", 0))
995993

996-
return state_intr
994+
return additional_arg_intr
997995
end
998996

999997
# run-time equivalent
1000-
function additional_arg_value(state, index)
998+
function additional_arg_value(arg, index)
1001999
@dispose ctx=Context() begin
1002-
T_state = convert(LLVMType, state)
1000+
T_arg = convert(LLVMType, arg)
10031001

10041002
# create function
1005-
llvm_f, _ = create_function(T_state)
1003+
llvm_f, _ = create_function(T_arg)
10061004
mod = LLVM.parent(llvm_f)
10071005

10081006
# get intrinsic
1009-
state_intr = additional_arg_intr(mod, T_state)
1010-
state_intr_ft = function_type(state_intr)
1007+
_additional_arg_intr = additional_arg_intr(mod, T_arg)
1008+
additional_arg_intr_ft = function_type(_additional_arg_intr)
10111009

10121010
# generate IR
10131011
@dispose builder=IRBuilder() begin
10141012
entry = BasicBlock(llvm_f, "entry")
10151013
position!(builder, entry)
10161014

1017-
val = call!(builder, state_intr_ft, state_intr, Value[ConstantInt(index)], "additional_arg")
1015+
val = call!(builder, additional_arg_intr_ft, _additional_arg_intr, Value[ConstantInt(index)], "additional_arg")
10181016

10191017
ret!(builder, val)
10201018
end
10211019

1022-
call_function(llvm_f, state)
1020+
call_function(llvm_f, arg)
10231021
end
10241022
end

0 commit comments

Comments
 (0)