Skip to content

Commit fb3088c

Browse files
committed
Add optimization callbacks that fire on a marker function
1 parent b37223e commit fb3088c

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

src/optim.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
44
tm = llvm_machine(job.config.target)
55

6-
global current_job
6+
global current_job # ScopedValue?
77
current_job = job
88

99
@dispose pb=NewPMPassBuilder() begin
@@ -24,6 +24,15 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
2424
return
2525
end
2626

27+
# TODO: Priority heap to provide order between different plugins
28+
"""
Registry of pipeline plugins, keyed by the name they were registered under.
"""
const PIPELINE_CALLBACKS = Dict{String,Any}()

"""
    register_plugin!(name::String, plugin)

Store `plugin` in `PIPELINE_CALLBACKS` under `name` and return it.

Throws an error when a plugin has already been registered under `name`;
the existing entry is left untouched in that case.
"""
function register_plugin!(name::String, plugin)
    # refuse to silently overwrite an earlier registration
    haskey(PIPELINE_CALLBACKS, name) &&
        error("GPUCompiler plugin with name $name is already registered")

    PIPELINE_CALLBACKS[name] = plugin
    return plugin
end
35+
2736
function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
2837
buildEarlySimplificationPipeline(mpm, job, opt_level)
2938
add!(mpm, AlwaysInlinerPass())
@@ -41,6 +50,9 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
4150
add!(fpm, WarnMissedTransformationsPass())
4251
end
4352
end
53+
for (name, callback) in PIPELINE_CALLBACKS
54+
add!(mpm, CallbackPass(name, callback))
55+
end
4456
buildIntrinsicLoweringPipeline(mpm, job, opt_level)
4557
buildCleanupPipeline(mpm, job, opt_level)
4658
end
@@ -423,3 +435,17 @@ function lower_ptls!(mod::LLVM.Module)
423435
return changed
424436
end
425437
# Module pass named "GPULowerPTLS" that runs `lower_ptls!`.
function LowerPTLSPass()
    return NewPMModulePass("GPULowerPTLS", lower_ptls!)
end
438+
439+
440+
"""
    callback_pass!(name, callback, mod::LLVM.Module)

Look for a function called `name` in `mod`. When it exists, invoke
`callback(job, marker, mod)` with the globally stored `current_job` and the
function that was found, and return whatever the callback returns. When it
does not exist, return `false` without calling `callback`.
"""
function callback_pass!(name, callback::F, mod::LLVM.Module) where {F}
    job = current_job::CompilerJob

    fns = functions(mod)
    # without the marker function in the module there is nothing to do
    haskey(fns, name) || return false

    return callback(job, fns[name], mod)
end
450+
451+
# Module pass named "CallbackPass<name>" that forwards the module to
# `callback_pass!` together with the captured `name` and `callback`.
function CallbackPass(name, callback)
    run_callback(mod) = callback_pass!(name, callback, mod)
    return NewPMModulePass("CallbackPass<$name>", run_callback)
end

test/ptx_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,15 @@ end
276276
@test "We did not crash!" != ""
277277
end
278278

279+
@testset "Pipeline callbacks" begin
    # kernel whose only job is to emit the marker call
    function kernel(x)
        PTX.mark(x)
        return
    end

    ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int}))

    # The marker name must be spelled exactly as it is registered
    # ("gpucompiler.mark"); a misspelled needle can never occur in the IR,
    # which would make this check pass unconditionally.
    @test !occursin("gpucompiler.mark", ir)
end
287+
279288
@testset "exception arguments" begin
280289
function kernel(a)
281290
unsafe_store!(a, trunc(Int, unsafe_load(a)))

test/ptx_testsetup.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,28 @@ end
1616
# Report `PTXKernelState` as the kernel state type for `PTXCompilerJob`s.
GPUCompiler.kernel_state_type(@nospecialize(job::PTXCompilerJob)) = PTXKernelState
1717
# Zero-argument accessor; being `@generated`, the expression returned by
# `GPUCompiler.kernel_state_value(PTXKernelState)` becomes its body.
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(PTXKernelState)
1818

19+
# Emit a call to the external marker function `gpucompiler.mark` with `x` as
# its single `Int` argument, returning `nothing`.
#
# The calling convention has to be spelled `llvmcall`; any other identifier in
# that position is not treated as a calling convention. Because the symbol is
# not an `llvm.*` intrinsic it also needs the `extern ` prefix, which is
# stripped from the emitted function name, so the call targets
# `gpucompiler.mark`.
function mark(x)
    return ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x)
end
22+
23+
"""
    remove_mark!(job, intrinsic, mod::LLVM.Module)

Walk over the uses of `intrinsic` and delete every user that itself has no
uses. Returns `true` when at least one user was deleted, `false` otherwise.
`job` and `mod` are accepted but not consulted here.
"""
function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
    deleted_any = false

    for use in uses(intrinsic)
        usr = user(use)

        # a user whose own value is still consumed is left alone;
        # the validator will detect this
        isempty(uses(usr)) || continue

        unsafe_delete!(LLVM.parent(usr), usr)
        deleted_any = true
    end

    return deleted_any
end
38+
39+
# Register `remove_mark!` as the pipeline plugin for the name "gpucompiler.mark".
GPUCompiler.register_plugin!("gpucompiler.mark", remove_mark!)
40+
1941
# a version of the test runtime that has some side effects, loading the kernel state
2042
# (so that we can test if kernel state arguments are appropriately optimized away)
2143
module PTXTestRuntime

0 commit comments

Comments
 (0)