Skip to content

Commit 828ee63

Browse files
committed
Add optimization callbacks that fire on a marker function
1 parent 5281e86 commit 828ee63

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

src/optim.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
44
tm = llvm_machine(job.config.target)
55

6-
global current_job
6+
global current_job # ScopedValue?
77
current_job = job
88

99
@dispose pb=NewPMPassBuilder() begin
@@ -14,6 +14,10 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
1414
register!(pb, LowerKernelStatePass())
1515
register!(pb, CleanupKernelStatePass())
1616

17+
for (name, callback) in PIPELINE_CALLBACKS
18+
register!(pb, CallbackPass(name, callback))
19+
end
20+
1721
add!(pb, NewPMModulePassManager()) do mpm
1822
buildNewPMPipeline!(mpm, job, opt_level)
1923
end
@@ -24,6 +28,15 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
2428
return
2529
end
2630

31+
# TODO: Priority heap to provide order between different plugins
32+
const PIPELINE_CALLBACKS = Dict{String, Any}()
33+
function register_plugin!(name::String, plugin)
34+
if haskey(PIPELINE_CALLBACKS, name)
35+
error("GPUCompiler plugin with name $name is already registered")
36+
end
37+
PIPELINE_CALLBACKS[name] = plugin
38+
end
39+
2740
function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
2841
buildEarlySimplificationPipeline(mpm, job, opt_level)
2942
add!(mpm, AlwaysInlinerPass())
@@ -41,6 +54,9 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
4154
add!(fpm, WarnMissedTransformationsPass())
4255
end
4356
end
57+
for (name, callback) in PIPELINE_CALLBACKS
58+
add!(mpm, CallbackPass(name, callback))
59+
end
4460
buildIntrinsicLoweringPipeline(mpm, job, opt_level)
4561
buildCleanupPipeline(mpm, job, opt_level)
4662
end
@@ -423,3 +439,17 @@ function lower_ptls!(mod::LLVM.Module)
423439
return changed
424440
end
425441
LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!)
442+
443+
444+
function callback_pass!(name, callback::F, mod::LLVM.Module) where F
445+
job = current_job::CompilerJob
446+
changed = false
447+
448+
if haskey(functions(mod), name)
449+
marker = functions(mod)[name]
450+
changed = callback(job, marker, mod)
451+
end
452+
return changed
453+
end
454+
455+
CallbackPass(name, callback) = NewPMModulePass("CallbackPass<$name>", (mod)->callback_pass!(name, callback, mod))

test/ptx_tests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testitem "PTX" setup=[PTX, Helpers] begin
22

33
using LLVM
4+
import InteractiveUtils
45

56
############################################################################################
67

@@ -276,6 +277,17 @@ end
276277
@test "We did not crash!" != ""
277278
end
278279

280+
@testset "Pipeline callbacks" begin
281+
function kernel(x)
282+
PTX.mark(x)
283+
return
284+
end
285+
ir = sprint(io->InteractiveUtils.code_llvm(io, kernel, Tuple{Int}))
286+
@test occursin("gpucompiler.mark", ir)
287+
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int}))
288+
@test !occursin("gpucompiler.mark", ir)
289+
end
290+
279291
@testset "exception arguments" begin
280292
function kernel(a)
281293
unsafe_store!(a, trunc(Int, unsafe_load(a)))

test/ptx_testsetup.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@testsetup module PTX
22

33
using GPUCompiler
4-
4+
import LLVM
55

66
# create a PTX-based test compiler, and generate reflection methods for it
77

@@ -16,6 +16,28 @@ end
1616
GPUCompiler.kernel_state_type(@nospecialize(job::PTXCompilerJob)) = PTXKernelState
1717
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(PTXKernelState)
1818

19+
function mark(x)
20+
ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x)
21+
end
22+
23+
function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
24+
changed = false
25+
26+
for use in LLVM.uses(intrinsic)
27+
val = LLVM.user(use)
28+
if isempty(LLVM.uses(val))
29+
LLVM.unsafe_delete!(LLVM.parent(val), val)
30+
changed = true
31+
else
32+
# the validator will detect this
33+
end
34+
end
35+
36+
return changed
37+
end
38+
39+
GPUCompiler.register_plugin!("gpucompiler.mark", remove_mark!)
40+
1941
# a version of the test runtime that has some side effects, loading the kernel state
2042
# (so that we can test if kernel state arguments are appropriately optimized away)
2143
module PTXTestRuntime

0 commit comments

Comments
 (0)