Skip to content

Commit efa4602

Browse files
committed
Only add the kernel state to functions that need it.
Some libraries are hard to rewrite.
1 parent 8f4cb60 commit efa4602

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

src/irgen.jl

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -588,8 +588,30 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
588588
# this is both for extern uses, and to make this transformation a two-step process.
589589
state_intr = kernel_state_intr(mod, T_state)
590590

591-
# add a state argument to every function
592-
worklist = filter(!isdeclaration, collect(functions(mod)))
591+
# determine which functions need a kernel state argument
592+
#
593+
# previously, we add the argument to every function and relied on unused arg elim to
594+
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
595+
# function pointers. such IR is hard to rewrite, so instead be more conservative.
596+
worklist = Set{LLVM.Function}([entry, state_intr])
597+
worklist_length = 0
598+
while worklist_length != length(worklist)
599+
# iteratively discover functions that use the intrinsic or any function calling it
600+
worklist_length = length(worklist)
601+
additions = LLVM.Function[]
602+
for f in worklist, use in uses(f)
603+
inst = user(use)::Instruction
604+
bb = LLVM.parent(inst)
605+
new_f = LLVM.parent(bb)
606+
in(new_f, worklist) || push!(additions, new_f)
607+
end
608+
for f in additions
609+
push!(worklist, f)
610+
end
611+
end
612+
delete!(worklist, state_intr)
613+
614+
# add a state argument
593615
workmap = Dict{LLVM.Function, LLVM.Function}()
594616
for f in worklist
595617
fn = LLVM.name(f)
@@ -726,14 +748,10 @@ function lower_kernel_state!(fun::LLVM.Function)
726748
return false
727749
end
728750

729-
# find the kernel state argument. this should be the first argument of the function.
730-
state_arg = parameters(fun)[1]
731-
T_state = convert(LLVMType, state; ctx)
732-
@assert llvmtype(state_arg) == T_state
733-
734751
# fixup all uses of the state getter to use the newly introduced function state argument
735752
if haskey(functions(mod), "julia.gpu.state_getter")
736753
state_intr = functions(mod)["julia.gpu.state_getter"]
754+
state_arg = nothing # only look-up when needed
737755

738756
Builder(ctx) do builder
739757
for use in uses(state_intr)
@@ -746,6 +764,14 @@ function lower_kernel_state!(fun::LLVM.Function)
746764
bb = LLVM.parent(inst)
747765
f = LLVM.parent(bb)
748766

767+
if state_arg === nothing
768+
# find the kernel state argument. this should be the first argument of
769+
# the function, but only when this function needs the state!
770+
state_arg = parameters(fun)[1]
771+
T_state = convert(LLVMType, state; ctx)
772+
@assert llvmtype(state_arg) == T_state
773+
end
774+
749775
replace_uses!(inst, state_arg)
750776

751777
@assert isempty(uses(inst))

0 commit comments

Comments
 (0)