Skip to content

Commit 9b54632

Browse files
committed
Add optimization callbacks that fire on a marker function
1 parent ced39bb commit 9b54632

File tree

5 files changed

+94
-3
lines changed

5 files changed

+94
-3
lines changed

src/driver.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,12 @@ const __llvm_initialized = Ref(false)
260260
end
261261
end
262262

263+
for (name, plugin) in PLUGINS
264+
if plugin.finalize_module !== nothing
265+
plugin.finalize_module(job, compiled, ir)
266+
end
267+
end
268+
263269
@timeit_debug to "IR post-processing" begin
264270
# mark everything internal except for entrypoints and any exported
265271
# global variables. this makes sure that the optimizer can, e.g.,
@@ -335,7 +341,7 @@ const __llvm_initialized = Ref(false)
335341
# we want to finish the module after optimization, so we cannot do so
336342
# during deferred code generation. Instead, process the merged module
337343
# from all the jobs here.
338-
if toplevel
344+
if toplevel # TODO: We should be able to remove this now
339345
entry = finish_ir!(job, ir, entry)
340346

341347
# for (job′, fn′) in jobs

src/optim.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
44
tm = llvm_machine(job.config.target)
55

6-
global current_job
6+
global current_job # ScopedValue?
77
current_job = job
88

99
@dispose pb=NewPMPassBuilder() begin
@@ -14,6 +14,12 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
1414
register!(pb, LowerKernelStatePass())
1515
register!(pb, CleanupKernelStatePass())
1616

17+
for (name, plugin) in PLUGINS
18+
if plugin.pipeline_callback !== nothing
19+
register!(pb, CallbackPass(name, plugin.pipeline_callback))
20+
end
21+
end
22+
1723
add!(pb, NewPMModulePassManager()) do mpm
1824
buildNewPMPipeline!(mpm, job, opt_level)
1925
end
@@ -24,6 +30,20 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
2430
return
2531
end
2632

33+
"""
    Plugin(finalize_module, pipeline_callback)

Pair of optional hooks contributed by a GPUCompiler plugin. Either field may be
`nothing`, in which case the corresponding hook is skipped by the compiler.

- `finalize_module`: invoked as `f(@nospecialize(job), compiled, mod::LLVM.Module)`.
- `pipeline_callback`: invoked as `f(@nospecialize(job), intrinsic, mod::LLVM.Module)`.
"""
struct Plugin
    finalize_module::Any
    pipeline_callback::Any
end

# TODO: Priority heap to provide order between different plugins
const PLUGINS = Dict{String, Plugin}()

"""
    register_plugin!(name, check=true; finalize_module=nothing, pipeline_callback=nothing)

Store a [`Plugin`](@ref) built from the given hooks in `PLUGINS` under `name` and
return it.

When `check` is `true` and `name` is already present, an error is raised and the
registry is left untouched. Passing `check=false` overwrites any existing entry.
"""
function register_plugin!(name::String, check::Bool=true;
                          finalize_module=nothing, pipeline_callback=nothing)
    already_present = haskey(PLUGINS, name)
    if already_present && check
        error("GPUCompiler plugin with name $name is already registered")
    end

    entry = Plugin(finalize_module, pipeline_callback)
    PLUGINS[name] = entry
    return entry
end
46+
2747
function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
2848
buildEarlySimplificationPipeline(mpm, job, opt_level)
2949
add!(mpm, AlwaysInlinerPass())
@@ -41,6 +61,11 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
4161
add!(fpm, WarnMissedTransformationsPass())
4262
end
4363
end
64+
for (name, plugin) in PLUGINS
65+
if plugin.pipeline_callback !== nothing
66+
add!(mpm, CallbackPass(name, plugin.pipeline_callback))
67+
end
68+
end
4469
buildIntrinsicLoweringPipeline(mpm, job, opt_level)
4570
buildCleanupPipeline(mpm, job, opt_level)
4671
end
@@ -423,3 +448,17 @@ function lower_ptls!(mod::LLVM.Module)
423448
return changed
424449
end
425450
LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!)
451+
452+
453+
# Module-pass body backing `CallbackPass`.
#
# Looks up the function called `name` in `mod`. If it exists, `callback` is invoked as
# `callback(job, marker, mod)` (with `job` taken from the global `current_job`) and its
# result is returned as the "changed" flag. If no such function exists the module is
# left alone and `false` is returned.
function callback_pass!(name, callback::F, mod::LLVM.Module) where F
    job = current_job::CompilerJob

    fns = functions(mod)
    haskey(fns, name) || return false

    return callback(job, fns[name], mod)
end
463+
464+
# Wrap `callback` in a module pass named `CallbackPass<name>`; when run, the pass
# forwards the module to `callback_pass!` together with `name` and `callback`.
function CallbackPass(name, callback)
    run_callback(mod) = callback_pass!(name, callback, mod)
    return NewPMModulePass("CallbackPass<$name>", run_callback)
end

test/plugin_testsetup.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
@testsetup module Plugin

using Test
using ReTestItems
import LLVM
import GPUCompiler

# Emit a call to the external symbol `gpucompiler.mark`, which the plugin registered
# below uses as its marker function.
mark(x) = ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x)

# Pipeline callback: erase every user of `intrinsic` that itself has no uses.
# Returns `true` if at least one user was erased, `false` otherwise.
function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
    erased_any = false

    for use in LLVM.uses(intrinsic)
        call = LLVM.user(use)

        # a user that still has uses is left in place; the validator will detect this
        isempty(LLVM.uses(call)) || continue

        LLVM.erase!(call)
        erased_any = true
    end

    return erased_any
end

# `check=false`: overwrite instead of erroring if the name is already registered.
GPUCompiler.register_plugin!("gpucompiler.mark", false, pipeline_callback=remove_mark!)

end

test/ptx_tests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testitem "PTX" setup=[PTX, Helpers] begin
22

33
using LLVM
4+
import InteractiveUtils
45

56
############################################################################################
67

@@ -406,7 +407,22 @@ precompile_test_harness("Inference caching") do load_path
406407
@test check_presence(identity_mi, token)
407408
end
408409
end
410+
end # testitem
409411

410412
############################################################################################
411413

414+
@testitem "PTX plugin" setup=[PTX, Plugin] begin

import InteractiveUtils

@testset "Pipeline callbacks" begin
    # kernel whose body is just the marker call
    function kernel(x)
        Plugin.mark(x)
        return
    end

    # the regular Julia pipeline keeps the marker call...
    native_ir = sprint(InteractiveUtils.code_llvm, kernel, Tuple{Int})
    @test occursin("gpucompiler.mark", native_ir)

    # ...whereas it must be gone from the IR produced by the PTX test compiler
    ptx_ir = sprint(PTX.code_llvm, kernel, Tuple{Int})
    @test !occursin("gpucompiler.mark", ptx_ir)
end
end #testitem

test/ptx_testsetup.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
using GPUCompiler
44

5-
65
# create a PTX-based test compiler, and generate reflection methods for it
76

87
include("runtime.jl")

0 commit comments

Comments
 (0)