Skip to content

Commit 40e6ad9

Browse files
committed
Add optimization callbacks that fire on a marker function
1 parent 5281e86 commit 40e6ad9

File tree

4 files changed

+77
-2
lines changed

4 files changed

+77
-2
lines changed

src/optim.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
44
tm = llvm_machine(job.config.target)
55

6-
global current_job
6+
global current_job # ScopedValue?
77
current_job = job
88

99
@dispose pb=NewPMPassBuilder() begin
@@ -14,6 +14,10 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
1414
register!(pb, LowerKernelStatePass())
1515
register!(pb, CleanupKernelStatePass())
1616

17+
for (name, callback) in PIPELINE_CALLBACKS
18+
register!(pb, CallbackPass(name, callback))
19+
end
20+
1721
add!(pb, NewPMModulePassManager()) do mpm
1822
buildNewPMPipeline!(mpm, job, opt_level)
1923
end
@@ -24,6 +28,15 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
2428
return
2529
end
2630

31+
# TODO: Priority heap to provide order between different plugins
32+
const PIPELINE_CALLBACKS = Dict{String, Any}()
33+
function register_plugin!(name::String, plugin)
34+
if haskey(PIPELINE_CALLBACKS, name)
35+
error("GPUCompiler plugin with name $name is already registered")
36+
end
37+
PIPELINE_CALLBACKS[name] = plugin
38+
end
39+
2740
function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
2841
buildEarlySimplificationPipeline(mpm, job, opt_level)
2942
add!(mpm, AlwaysInlinerPass())
@@ -41,6 +54,9 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
4154
add!(fpm, WarnMissedTransformationsPass())
4255
end
4356
end
57+
for (name, callback) in PIPELINE_CALLBACKS
58+
add!(mpm, CallbackPass(name, callback))
59+
end
4460
buildIntrinsicLoweringPipeline(mpm, job, opt_level)
4561
buildCleanupPipeline(mpm, job, opt_level)
4662
end
@@ -423,3 +439,17 @@ function lower_ptls!(mod::LLVM.Module)
423439
return changed
424440
end
425441
LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!)
442+
443+
444+
function callback_pass!(name, callback::F, mod::LLVM.Module) where F
445+
job = current_job::CompilerJob
446+
changed = false
447+
448+
if haskey(functions(mod), name)
449+
marker = functions(mod)[name]
450+
changed = callback(job, marker, mod)
451+
end
452+
return changed
453+
end
454+
455+
CallbackPass(name, callback) = NewPMModulePass("CallbackPass<$name>", (mod)->callback_pass!(name, callback, mod))

test/plugin_testsetup.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
@testsetup module Plugin
2+
3+
using Test
4+
using ReTestItems
5+
import LLVM
6+
import GPUCompiler
7+
8+
function mark(x)
9+
ccall("extern gpucompiler.mark", llvmcall, Nothing, (Int,), x)
10+
end
11+
12+
function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
13+
changed = false
14+
15+
for use in LLVM.uses(intrinsic)
16+
val = LLVM.user(use)
17+
if isempty(LLVM.uses(val))
18+
LLVM.erase!(val)
19+
changed = true
20+
else
21+
# the validator will detect this
22+
end
23+
end
24+
25+
return changed
26+
end
27+
28+
GPUCompiler.register_plugin!("gpucompiler.mark", remove_mark!)
29+
30+
end

test/ptx_tests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testitem "PTX" setup=[PTX, Helpers] begin
22

33
using LLVM
4+
import InteractiveUtils
45

56
############################################################################################
67

@@ -406,7 +407,22 @@ precompile_test_harness("Inference caching") do load_path
406407
@test check_presence(identity_mi, token)
407408
end
408409
end
410+
end # testitem
409411

410412
############################################################################################
411413

414+
@testitem "PTX plugin" setup=[PTX, Plugin] begin
415+
416+
import InteractiveUtils
417+
418+
@testset "Pipeline callbacks" begin
419+
function kernel(x)
420+
Plugin.mark(x)
421+
return
422+
end
423+
ir = sprint(io->InteractiveUtils.code_llvm(io, kernel, Tuple{Int}))
424+
@test occursin("gpucompiler.mark", ir)
425+
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int}))
426+
@test !occursin("gpucompiler.mark", ir)
412427
end
428+
end #testitem

test/ptx_testsetup.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
using GPUCompiler
44

5-
65
# create a PTX-based test compiler, and generate reflection methods for it
76

87
include("runtime.jl")

0 commit comments

Comments
 (0)