Skip to content

Commit 6a9582f

Browse files
committed
Lower uses of the state getter intrinsic late during optimization.
1 parent 2cbc107 commit 6a9582f

File tree

3 files changed

+115
-71
lines changed

3 files changed

+115
-71
lines changed

src/irgen.jl

Lines changed: 104 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -545,20 +545,22 @@ end
545545

546546
# kernel state arguments
547547
#
548-
# add a state argument every function in the module, and lower calls to the
549-
# `julia.gpu.state_getter` intrinsics to use this newly-introduced state argument.
550-
#
551-
# the type of the state is determined by the `kernel_state_type` interface, and is passed
552-
# as a byval pointer so that (1) the intrinsic can use an opaque pointer for users to
553-
# cast to an appropriate type, while (2) ensuring the state resides in thread-local memory
554-
# so that it can be used without synchronizing global-memory accesses.
548+
# to facilitate passing stateful information to kernels without having to recompile, e.g.,
549+
# the storage location for exception flags, or the location of a I/O buffer, we enable the
550+
# back-end to specify a Julia object that will be passed to the kernel by-value, and to
551+
# every called function by-reference. Access to this object is done using the
552+
# `julia.gpu.state_getter` intrinsic, which returns an opaque pointer to the state object.
553+
# after optimization, these intrinsics will be lowered to refer to the state argument.
554+
555+
# add a state argument to every function in the module, starting from the kernel entry point
555556
function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
556557
entry::LLVM.Function)
557558
ctx = context(mod)
558559
entry_fn = LLVM.name(entry)
559560

560561
# check if we even need a kernel state argument
561562
state = kernel_state_type(job)
563+
@assert job.source.kernel
562564
if state === Nothing
563565
return false
564566
end
@@ -569,12 +571,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
569571
# this is both for extern uses, and to make this transformation a two-step process.
570572
T_int8 = LLVM.IntType(8; ctx)
571573
T_pint8 = LLVM.PointerType(T_int8)
572-
state_intr = if haskey(functions(mod), "julia.gpu.state_getter")
573-
functions(mod)["julia.gpu.state_getter"]
574-
else
575-
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_pint8))
576-
end
577-
push!(function_attributes(state_intr), EnumAttribute("readnone", 0; ctx))
574+
state_intr = kernel_state_intr(mod)
578575

579576
# add a state argument to every function
580577
worklist = filter(!isdeclaration, collect(functions(mod)))
@@ -659,10 +656,10 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
659656

660657
# forward the state argument
661658
position!(builder, val)
662-
state = call!(builder, state_intr, Value[], "state")
663-
state = bitcast!(builder, state, T_ptr_state)
659+
untyped_state = call!(builder, state_intr, Value[], "state")
660+
typed_state = bitcast!(builder, untyped_state, T_ptr_state)
664661
new_val = if val isa LLVM.CallInst
665-
call!(builder, new_f, [state, operands(val)[1:end-1]...])
662+
call!(builder, new_f, [typed_state, operands(val)[1:end-1]...])
666663
else
667664
# TODO: invoke and callbr
668665
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
@@ -698,27 +695,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
698695
unsafe_delete!(mod, f)
699696
end
700697

701-
# fixup all uses of the state getter to use the newly introduced function state argument
702-
Builder(ctx) do builder
703-
for use in uses(state_intr)
704-
inst = user(use)
705-
@assert inst isa LLVM.CallInst
706-
707-
position!(builder, inst)
708-
bb = LLVM.parent(inst)
709-
f = LLVM.parent(bb)
710-
711-
state = parameters(f)[1]
712-
state = bitcast!(builder, state, T_pint8)
713-
replace_uses!(inst, state)
714-
715-
@assert isempty(uses(inst))
716-
unsafe_delete!(LLVM.parent(inst), inst)
717-
end
718-
end
719-
720-
# HACK: add a dummy use of the kernel state pointer to ensure it is always available
721-
# also see `kernel_state_argument` below.
698+
# HACK: add a dummy use of the kernel state pointer to ensure it survives optimization
722699
dummy_user = if haskey(functions(mod), "julia.gpu.state_user")
723700
functions(mod)["julia.gpu.state_user"]
724701
else
@@ -731,54 +708,112 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
731708
call!(builder, dummy_user, [parameters(entry)[1]])
732709
end
733710

734-
# clean-up
735-
@assert isempty(uses(state_intr))
736-
unsafe_delete!(mod, state_intr)
737-
738-
# don't pass the state when unnecessary
739-
# XXX: isn't this done during optimization as well?
740-
ModulePassManager() do pm
741-
dead_arg_elimination!(pm)
742-
run!(pm, mod)
743-
end
744-
745711
return true
746712
end
747713

748-
# return a value pointing to the state argument in a given function.
749-
function kernel_state_argument(f::LLVM.Function, state::Type)
750-
ctx = context(f)
751-
mod = LLVM.parent(f)
714+
# lower calls to the state getter intrinsic. this is a two-step process, so that the state
715+
# argument can be added before optimization, and that optimization can introduce new uses
716+
# before the intrinsic getting lowered late during optimization.
717+
#
718+
# the reason we want to add the state argument before optimization, is that the initial
719+
# argument is marked byval, but some backends need to eagerly lower that byval property
720+
# (because the LLVM back-end doesn't support emitting code for it). That lowering typically
721+
# generates a lot of expensive code, so _needs_ to be optimized.
722+
function lower_kernel_state!(fun::LLVM.Function)
723+
job = current_job::CompilerJob
724+
mod = LLVM.parent(fun)
725+
ctx = context(fun)
726+
changed = false
727+
728+
# check if we even need a kernel state argument
729+
if !job.source.kernel
730+
# only kernels have had a kernel state argument added
731+
return false
732+
end
733+
state = kernel_state_type(job)
734+
if state === Nothing
735+
return false
736+
end
752737

738+
# find the kernel state argument. normally, this is the first argument of the function.
739+
state_arg = nothing
753740
T_state = convert(LLVMType, state; ctx)
754741
T_ptr_state = LLVM.PointerType(T_state)
755-
756-
arg = parameters(f)[1]
757-
if llvmtype(arg) == T_ptr_state
758-
return arg
742+
first_arg = parameters(fun)[1]
743+
if llvmtype(first_arg) == T_ptr_state
744+
state_arg = first_arg
759745
end
760746

761-
# if the first argument isn't a valid kernel state pointer, this probably means we're
762-
# in a kernel function whose byval-annotated kernel state argument got lowered eagerly.
763-
# to make sure we can still get a pointer to the kernel state, we've emitted a dummy
764-
# use, which we can use here to get a pointer to the kernel state.
747+
# with kernels, the story is more complicated: the kernel state argument is marked byval,
748+
# and it's possible we eagerly lowered that pointer to a value. to retrieve the state,
749+
# look for the alloca slot the argument was stored in via the dummy use we introduced.
765750
#
766751
# this is obviously a hack, stemming from the fact that while lowering Julia intrinsics
767752
# (which needs to happen _after_ optimization) we may have to emit calls to the GPU
768753
# runtime while those functions may already have had their kernel state arguments added
769754
# (which we do _before_ optimization to make sure that any lowered byval performs well).
770-
@assert llvmtype(arg) == T_state
771-
dummy_user = functions(mod)["julia.gpu.state_user"]
772-
for use in uses(dummy_user)
773-
call = user(use)
774-
bb = LLVM.parent(call)
775-
if LLVM.parent(bb) == f
776-
arg = operands(call)[1]
777-
return arg
755+
if state_arg === nothing
756+
@assert llvmtype(first_arg) == T_state
757+
dummy_user = functions(mod)["julia.gpu.state_user"]
758+
for use in uses(dummy_user)
759+
call = user(use)
760+
bb = LLVM.parent(call)
761+
if LLVM.parent(bb) == fun
762+
state_arg = operands(call)[1]
763+
break
764+
end
765+
end
766+
end
767+
768+
if state_arg === nothing
769+
error("Internal compiler error: could not reconstruct kernel state argument")
770+
end
771+
772+
# get the intrinsic returning an opaque pointer to the kernel state.
773+
T_int8 = LLVM.IntType(8; ctx)
774+
T_pint8 = LLVM.PointerType(T_int8)
775+
state_intr = kernel_state_intr(mod)
776+
777+
# fixup all uses of the state getter to use the newly introduced function state argument
778+
Builder(ctx) do builder
779+
for use in uses(state_intr)
780+
inst = user(use)
781+
@assert inst isa LLVM.CallInst
782+
783+
position!(builder, inst)
784+
bb = LLVM.parent(inst)
785+
f = LLVM.parent(bb)
786+
787+
untyped_state = bitcast!(builder, state_arg, T_pint8)
788+
replace_uses!(inst, untyped_state)
789+
790+
@assert isempty(uses(inst))
791+
unsafe_delete!(LLVM.parent(inst), inst)
792+
793+
changed = true
778794
end
779795
end
780796

781-
error("Internal compiler error: could not reconstruct kernel state argument")
797+
# clean-up
798+
@assert isempty(uses(state_intr))
799+
unsafe_delete!(mod, state_intr)
800+
801+
return changed
802+
end
803+
804+
function kernel_state_intr(mod::LLVM.Module)
805+
ctx = context(mod)
806+
T_int8 = LLVM.IntType(8; ctx)
807+
T_pint8 = LLVM.PointerType(T_int8)
808+
809+
state_intr = if haskey(functions(mod), "julia.gpu.state_getter")
810+
functions(mod)["julia.gpu.state_getter"]
811+
else
812+
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_pint8))
813+
end
814+
push!(function_attributes(state_intr), EnumAttribute("readnone", 0; ctx))
815+
816+
return state_intr
782817
end
783818

784819
# run-time equivalent (untyped)

src/optim.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
182182

183183
add!(pm, FunctionPass("LowerGCFrame", lower_gc_frame!))
184184

185+
# GC lowering is the last pass that may introduce calls to the runtime library,
186+
# and thus additional uses of the kernel state.
187+
add!(pm, FunctionPass("LowerKernelState", lower_kernel_state!))
188+
185189
# remove dead uses of ptls
186190
aggressive_dce!(pm)
187191
add!(pm, ModulePass("LowerPTLS", lower_ptls!))

src/rtlib.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,13 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[
4848
# in which case we need to manually get and pass the state.
4949
args = Value[args...]
5050
if state !== Nothing
51-
state_val = kernel_state_argument(f, state)
52-
pushfirst!(args, state_val)
51+
T_state = convert(LLVMType, state; ctx)
52+
T_ptr_state = LLVM.PointerType(T_state)
53+
54+
state_intr = kernel_state_intr(mod)
55+
untyped_state = call!(builder, state_intr, Value[], "state")
56+
typed_state = bitcast!(builder, untyped_state, T_ptr_state)
57+
pushfirst!(args, typed_state)
5358
end
5459

5560
# runtime functions are written in Julia, while we're calling from LLVM,

0 commit comments

Comments
 (0)