Skip to content

Commit a031d6c

Browse files
committed
Use a better cleanup pass.
1 parent 6a9582f commit a031d6c

File tree

4 files changed

+38
-15
lines changed

4 files changed

+38
-15
lines changed

src/driver.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -307,17 +307,6 @@ const __llvm_initialized = Ref(false)
307307
end
308308
end
309309

310-
# remove the kernel state dummy use
311-
if haskey(functions(ir), "julia.gpu.state_user")
312-
dummy_user = functions(ir)["julia.gpu.state_user"]
313-
for use in uses(dummy_user)
314-
call = user(use)
315-
unsafe_delete!(LLVM.parent(call), call)
316-
end
317-
@assert isempty(uses(dummy_user))
318-
unsafe_delete!(ir, dummy_user)
319-
end
320-
321310
if ccall(:jl_is_debugbuild, Cint, ()) == 1
322311
@timeit_debug to "verification" verify(ir)
323312
end

src/irgen.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ function lower_kernel_state!(fun::LLVM.Function)
728728
# check if we even need a kernel state argument
729729
if !job.source.kernel
730730
# only kernels have had a kernel state argument added
731+
# XXX: for consistency, also add the state to non-kernel compilation jobs?
731732
return false
732733
end
733734
state = kernel_state_type(job)
@@ -779,6 +780,8 @@ function lower_kernel_state!(fun::LLVM.Function)
779780
for use in uses(state_intr)
780781
inst = user(use)
781782
@assert inst isa LLVM.CallInst
783+
bb = LLVM.parent(inst)
784+
LLVM.parent(bb) == fun || continue
782785

783786
position!(builder, inst)
784787
bb = LLVM.parent(inst)
@@ -794,9 +797,35 @@ function lower_kernel_state!(fun::LLVM.Function)
794797
end
795798
end
796799

797-
# clean-up
798-
@assert isempty(uses(state_intr))
799-
unsafe_delete!(mod, state_intr)
800+
return changed
801+
end
802+
803+
function cleanup_kernel_state!(mod::LLVM.Module)
804+
job = current_job::CompilerJob
805+
ctx = context(mod)
806+
changed = false
807+
808+
# remove the getter intrinsic
809+
if haskey(functions(mod), "julia.gpu.state_getter")
810+
intr = functions(mod)["julia.gpu.state_getter"]
811+
if isempty(uses(intr))
812+
# if we're not emitting a kernel, we can't resolve the intrinsic to an argument.
813+
unsafe_delete!(mod, intr)
814+
changed = true
815+
end
816+
end
817+
818+
# remove the kernel state dummy use
819+
if haskey(functions(mod), "julia.gpu.state_user")
820+
intr = functions(mod)["julia.gpu.state_user"]
821+
for use in uses(intr)
822+
call = user(use)
823+
unsafe_delete!(LLVM.parent(call), call)
824+
end
825+
@assert isempty(uses(intr))
826+
unsafe_delete!(mod, intr)
827+
changed = true
828+
end
800829

801830
return changed
802831
end

src/optim.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
185185
# GC lowering is the last pass that may introduce calls to the runtime library,
186186
# and thus additional uses of the kernel state.
187187
add!(pm, FunctionPass("LowerKernelState", lower_kernel_state!))
188+
add!(pm, ModulePass("CleanupKernelState", cleanup_kernel_state!))
188189

189190
# remove dead uses of ptls
190191
aggressive_dce!(pm)

test/ptx.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ end
9292
# state should only passed to device functions that use it
9393

9494
@eval @noinline kernel_state_child1(ptr) = unsafe_load(ptr)
95-
@eval @noinline kernel_state_child2() = ptx_kernel_state().data
95+
@eval @noinline function kernel_state_child2()
96+
data = ptx_kernel_state().data
97+
ptr = reinterpret(Ptr{Int}, data)
98+
unsafe_load(ptr)
99+
end
96100

97101
function kernel(ptr)
98102
unsafe_store!(ptr, kernel_state_child1(ptr) + kernel_state_child2())

0 commit comments

Comments
 (0)