Skip to content

Commit 3e53a3f

Browse files
committed
Rework kernel state handling.
- actually treat the getter intrinsic as opaque - emit a dummy use of the state pointer so that we can be sure to be able to recover it after optimization - mark the intrinsic as readnone
1 parent c57af09 commit 3e53a3f

File tree

3 files changed

+105
-31
lines changed

3 files changed

+105
-31
lines changed

src/driver.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,17 @@ const __llvm_initialized = Ref(false)
299299
entry = functions(ir)[entry_fn]
300300
end
301301

302+
# remove the kernel state dummy use
303+
if haskey(functions(ir), "julia.gpu.state_user")
304+
dummy_user = functions(ir)["julia.gpu.state_user"]
305+
for use in uses(dummy_user)
306+
call = user(use)
307+
unsafe_delete!(LLVM.parent(call), call)
308+
end
309+
@assert isempty(uses(dummy_user))
310+
unsafe_delete!(ir, dummy_user)
311+
end
312+
302313
if ccall(:jl_is_debugbuild, Cint, ()) == 1
303314
@timeit_debug to "verification" verify(ir)
304315
end

src/interface.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -205,20 +205,8 @@ function finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry
205205
entry_fn = LLVM.name(entry)
206206

207207
# add the kernel state, and lower calls to the `julia.gpu.state_getter` intrinsic.
208-
# we do this _after_ optimization, because the runtime is linked after optimization too.
209208
if job.source.kernel
210-
state = kernel_state_type(job)
211-
if state !== Nothing
212-
T_state = convert(LLVMType, state; ctx)
213-
add_kernel_state!(job, mod, entry, T_state)
214-
end
215-
216-
# don't pass the state when unnecessary
217-
# XXX: only apply in add_kernel_state! when needed?
218-
ModulePassManager() do pm
219-
dead_arg_elimination!(pm)
220-
run!(pm, mod)
221-
end
209+
add_kernel_state!(job, mod, entry)
222210
end
223211

224212
return functions(mod)[entry_fn]

src/irgen.jl

Lines changed: 93 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -553,18 +553,28 @@ end
553553
# cast to an appropriate type, while (2) ensuring the state resides in thread-local memory
554554
# so that it can be used without synchronizing global-memory accesses.
555555
function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
556-
entry::LLVM.Function, T_state::LLVMType)
556+
entry::LLVM.Function)
557557
ctx = context(mod)
558+
entry_fn = LLVM.name(entry)
559+
560+
# check if we even need a kernel state argument
561+
state = kernel_state_type(job)
562+
if state === Nothing
563+
return false
564+
end
565+
T_state = convert(LLVMType, state; ctx)
566+
T_ptr_state = LLVM.PointerType(T_state)
558567

559568
# intrinsic returning an opaque pointer to the kernel state.
560569
# this is both for extern uses, and to make this transformation a two-step process.
561-
T_ptr_state = LLVM.PointerType(T_state)
562-
state_getter = if haskey(functions(mod), "julia.gpu.state_getter")
570+
T_int8 = LLVM.IntType(8; ctx)
571+
T_pint8 = LLVM.PointerType(T_int8)
572+
state_intr = if haskey(functions(mod), "julia.gpu.state_getter")
563573
functions(mod)["julia.gpu.state_getter"]
564574
else
565-
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_ptr_state))
575+
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_int8))
566576
end
567-
push!(function_attributes(state_getter), EnumAttribute("readnone", 0; ctx))
577+
push!(function_attributes(state_intr), EnumAttribute("readnone", 0; ctx))
568578

569579
# add a state argument to every function
570580
worklist = filter(!isdeclaration, collect(functions(mod)))
@@ -649,7 +659,8 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
649659

650660
# forward the state argument
651661
position!(builder, val)
652-
state = call!(builder, state_getter, Value[], "state")
662+
state = call!(builder, state_intr, Value[], "state")
663+
state = bitcast!(builder, state, T_ptr_state)
653664
new_val = if val isa LLVM.CallInst
654665
call!(builder, new_f, [state, operands(val)[1:end-1]...])
655666
else
@@ -688,25 +699,89 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
688699
end
689700

690701
# fixup all uses of the state getter to use the newly introduced function state argument
691-
for use in uses(state_getter)
692-
inst = user(use)
693-
@assert inst isa LLVM.CallInst
702+
Builder(ctx) do builder
703+
for use in uses(state_intr)
704+
inst = user(use)
705+
@assert inst isa LLVM.CallInst
694706

695-
bb = LLVM.parent(inst)
696-
f = LLVM.parent(bb)
707+
position!(builder, inst)
708+
bb = LLVM.parent(inst)
709+
f = LLVM.parent(bb)
697710

698-
replace_uses!(inst, parameters(f)[1])
699-
@assert isempty(uses(inst))
700-
unsafe_delete!(LLVM.parent(inst), inst)
711+
state = parameters(f)[1]
712+
state = bitcast!(builder, state, T_int8)
713+
replace_uses!(inst, state)
714+
715+
@assert isempty(uses(inst))
716+
unsafe_delete!(LLVM.parent(inst), inst)
717+
end
718+
end
719+
720+
# HACK: add a dummy use of the kernel state pointer to ensure it is always available
721+
# also see `kernel_state_argument` below.
722+
dummy_user = if haskey(functions(mod), "julia.gpu.state_user")
723+
functions(mod)["julia.gpu.state_user"]
724+
else
725+
LLVM.Function(mod, "julia.gpu.state_user",
726+
LLVM.FunctionType(LLVM.VoidType(ctx), [T_ptr_state]))
727+
end
728+
entry = functions(mod)[entry_fn]
729+
Builder(ctx) do builder
730+
position!(builder, first(instructions(first(blocks(entry)))))
731+
call!(builder, dummy_user, [parameters(entry)[1]])
701732
end
702733

703734
# clean-up
704-
@assert isempty(uses(state_getter))
705-
unsafe_delete!(mod, state_getter)
735+
@assert isempty(uses(state_intr))
736+
unsafe_delete!(mod, state_intr)
737+
738+
# don't pass the state when unnecessary
739+
# XXX: isn't this done during optimization as well?
740+
ModulePassManager() do pm
741+
dead_arg_elimination!(pm)
742+
run!(pm, mod)
743+
end
744+
745+
return true
746+
end
747+
748+
# return a value pointing to the state argument in a given function.
749+
function kernel_state_argument(f::LLVM.Function, state::Type)
750+
ctx = context(f)
751+
mod = LLVM.parent(f)
752+
753+
T_state = convert(LLVMType, state; ctx)
754+
T_ptr_state = LLVM.PointerType(T_state)
755+
756+
arg = parameters(f)[1]
757+
if llvmtype(arg) == T_ptr_state
758+
return arg
759+
end
760+
761+
# if the first argument isn't a valid kernel state pointer, this probably means we're
762+
# in a kernel function whose byval-annotated kernel state argument got lowered eagerly.
763+
# to make sure we can still get a pointer to the kernel state, we've emitted a dummy
764+
# use, which we can use here to get a pointer to the kernel state.
765+
#
766+
# this is obviously a hack, stemming from the fact that while lowering Julia intrinsics
767+
# (which needs to happen _after_ optimization) we may have to emit calls to the GPU
768+
# runtime while those functions may already have had their kernel state arguments added
769+
# (which we do _before_ optimization to make sure that any lowered byval performs well).
770+
@assert llvmtype(arg) == T_state
771+
dummy_user = functions(mod)["julia.gpu.state_user"]
772+
for use in uses(dummy_user)
773+
call = user(use)
774+
bb = LLVM.parent(call)
775+
if LLVM.parent(bb) == f
776+
arg = operands(call)[1]
777+
return arg
778+
end
779+
end
706780

707-
return
781+
error("Internal compiler error: could not reconstruct kernel state argument")
708782
end
709783

784+
# run-time equivalent (untyped)
710785
@inline kernel_state_pointer() = Base.llvmcall(("""
711786
declare i8* @julia.gpu.state_getter()
712787
@@ -716,5 +791,5 @@ end
716791
ret i64 %ptr
717792
}
718793
719-
attributes #0 = { alwaysinline }""", "entry"),
794+
attributes #0 = { alwaysinline readnone }""", "entry"),
720795
Ptr{Cvoid}, Tuple{})

0 commit comments

Comments
 (0)