Skip to content

Commit 623c4c6

Browse files
authored
Support NewPM in some places (#2710)
1 parent 53b2461 commit 623c4c6

File tree

3 files changed

+32
-134
lines changed

3 files changed

+32
-134
lines changed

src/compiler.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,9 +1263,12 @@ function nested_codegen!(
12631263
edges = edges::Vector{Any}
12641264
push!(edges, funcspec)
12651265

1266-
LLVM.ModulePassManager() do pm
1267-
API.AddPreserveNVVMPass!(pm, true) #=Begin=#
1268-
LLVM.run!(pm, otherMod)
1266+
LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
1267+
registerEnzymeAndPassPipeline!(pb)
1268+
LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
1269+
LLVM.add!(mpm, PreserveNVVMPass())
1270+
end
1271+
LLVM.run!(pb, mod)
12691272
end
12701273

12711274
if DumpPreNestedCheck[]
@@ -2752,10 +2755,7 @@ function enzyme!(
27522755
for f in collect(functions(mod))
27532756
API.EnzymeFixupBatchedJuliaCallingConvention(f)
27542757
end
2755-
ModulePassManager() do pm
2756-
dce!(pm)
2757-
LLVM.run!(pm, mod)
2758-
end
2758+
run!(DCEPass(), mod)
27592759
fix_decayaddr!(mod)
27602760
adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname]
27612761
augmented_primalf =
@@ -4502,9 +4502,12 @@ function GPUCompiler.compile_unhooked(output::Symbol, job::CompilerJob{<:EnzymeT
45024502
permit_inlining!(f)
45034503
end
45044504

4505-
LLVM.ModulePassManager() do pm
4506-
API.AddPreserveNVVMPass!(pm, true) #=Begin=#
4507-
LLVM.run!(pm, mod)
4505+
LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
4506+
registerEnzymeAndPassPipeline!(pb)
4507+
LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
4508+
LLVM.add!(mpm, PreserveNVVMPass())
4509+
end
4510+
LLVM.run!(pb, mod)
45084511
end
45094512

45104513
primalf = meta.entry
@@ -5164,10 +5167,7 @@ end
51645167
push!(toremove, name(f))
51655168
end
51665169
end
5167-
ModulePassManager() do pm
5168-
always_inliner!(pm)
5169-
LLVM.run!(pm, mod)
5170-
end
5170+
run!(AlwaysInlinerPass(), mod)
51715171
for fname in toremove
51725172
if haskey(functions(mod), fname)
51735173
f = functions(mod)[fname]
@@ -5186,10 +5186,14 @@ end
51865186
augmented_primalf = nothing
51875187
end
51885188

5189-
LLVM.ModulePassManager() do pm
5190-
API.AddPreserveNVVMPass!(pm, false) #=Begin=#
5191-
LLVM.run!(pm, mod)
5189+
LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
5190+
registerEnzymeAndPassPipeline!(pb)
5191+
LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
5192+
LLVM.add!(mpm, PreserveNVVMEndPass())
5193+
end
5194+
LLVM.run!(pb, mod)
51925195
end
5196+
51935197
if !(primal_target isa GPUCompiler.NativeCompilerTarget)
51945198
mark_gpu_intrinsics!(primal_target, mod)
51955199
end

src/compiler/optimize.jl

Lines changed: 9 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,13 @@
1-
struct PipelineConfig
2-
Speedup::Cint
3-
Size::Cint
4-
lower_intrinsics::Cint
5-
dump_native::Cint
6-
external_use::Cint
7-
llvm_only::Cint
8-
always_inline::Cint
9-
enable_early_simplifications::Cint
10-
enable_early_optimizations::Cint
11-
enable_scalar_optimizations::Cint
12-
enable_loop_optimizations::Cint
13-
enable_vector_pipeline::Cint
14-
remove_ni::Cint
15-
cleanup::Cint
1+
# Register Enzyme's pass-registration callback with the pass builder `pb`.
#
# Looks up the address of the `registerEnzymeAndPassPipeline` symbol in
# `API.libEnzyme` and pushes it onto the builder's extension registration
# callbacks (`pb.exts`).
# NOTE(review): presumably this is what lets Enzyme-provided passes be resolved
# by name when the builder is later run — confirm against the LLVM.jl NewPM
# extension API and the Enzyme C entry point.
function registerEnzymeAndPassPipeline!(pb::NewPMPassBuilder)
2+
# Raw pointer to the callback symbol inside libEnzyme.
enzyme_callback = cglobal((:registerEnzymeAndPassPipeline, API.libEnzyme))
3+
# Hand the callback to the pass builder's extensions object.
LLVM.API.LLVMPassBuilderExtensionsPushRegistrationCallbacks(pb.exts, enzyme_callback)
164
end
175

18-
const RunAttributor = Ref(true)
19-
20-
function pipeline_options(;
21-
lower_intrinsics::Bool = true,
22-
dump_native::Bool = false,
23-
external_use::Bool = false,
24-
llvm_only::Bool = false,
25-
always_inline::Bool = true,
26-
enable_early_simplifications::Bool = true,
27-
enable_early_optimizations::Bool = true,
28-
enable_scalar_optimizations::Bool = true,
29-
enable_loop_optimizations::Bool = true,
30-
enable_vector_pipeline::Bool = true,
31-
remove_ni::Bool = true,
32-
cleanup::Bool = true,
33-
Size::Cint = 0,
34-
Speedup::Cint = 3,
35-
)
36-
return PipelineConfig(
37-
Speedup,
38-
Size,
39-
lower_intrinsics,
40-
dump_native,
41-
external_use,
42-
llvm_only,
43-
always_inline,
44-
enable_early_simplifications,
45-
enable_early_optimizations,
46-
enable_scalar_optimizations,
47-
enable_loop_optimizations,
48-
enable_vector_pipeline,
49-
remove_ni,
50-
cleanup,
51-
)
52-
end
6+
LLVM.@function_pass "jl-inst-simplify" JLInstSimplifyPass
7+
LLVM.@module_pass "preserve-nvvm" PreserveNVVMPass
8+
LLVM.@module_pass "preserve-nvvm-end" PreserveNVVMEndPass
539

54-
function run_jl_pipeline(pm::ModulePassManager, tm::LLVM.TargetMachine; kwargs...)
55-
config = Ref(pipeline_options(; kwargs...))
56-
function jl_pipeline(m)
57-
@dispose pb = NewPMPassBuilder() begin
58-
add!(pb, NewPMModulePassManager()) do mpm
59-
@ccall jl_build_newpm_pipeline(
60-
mpm.ref::Ptr{Cvoid},
61-
pb.ref::Ptr{Cvoid},
62-
config::Ptr{PipelineConfig},
63-
)::Cvoid
64-
end
65-
LLVM.run!(mpm, m, tm)
66-
end
67-
return true
68-
end
69-
add!(pm, ModulePass("JLPipeline", jl_pipeline))
70-
end
10+
const RunAttributor = Ref(true)
7111

7212
@static if VERSION < v"1.11.0-DEV.428"
7313
else
@@ -215,22 +155,7 @@ function loop_optimizations_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachi
215155
loop_unswitch!(pm)
216156
end
217157
else
218-
run_jl_pipeline(
219-
pm,
220-
tm;
221-
lower_intrinsics = false,
222-
dump_native = false,
223-
external_use = false,
224-
llvm_only = false,
225-
always_inline = false,
226-
enable_early_simplifications = false,
227-
enable_early_optimizations = false,
228-
enable_scalar_optimizations = false,
229-
enable_loop_optimizations = true,
230-
enable_vector_pipeline = false,
231-
remove_ni = false,
232-
cleanup = false,
233-
)
158+
@assert false
234159
end
235160
end
236161

@@ -253,36 +178,7 @@ function more_loop_optimizations_tm!(pm::LLVM.ModulePassManager, tm::LLVM.Target
253178
loop_deletion!(pm)
254179
loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll
255180
else
256-
# LowerSIMDLoopPass
257-
# LoopRotatePass [opt >= 2]
258-
# LICMPass
259-
# JuliaLICMPass
260-
# SimpleLoopUnswitchPass
261-
# LICMPass
262-
# JuliaLICMPass
263-
# IRCEPass
264-
# LoopInstSimplifyPass
265-
# - in ours this is instcombine with jlinstsimplify
266-
# LoopIdiomRecognizePass
267-
# IndVarSimplifyPass
268-
# LoopDeletionPass
269-
# LoopFullUnrollPass
270-
run_jl_pipeline(
271-
pm,
272-
tm;
273-
lower_intrinsics = false,
274-
dump_native = false,
275-
external_use = false,
276-
llvm_only = false,
277-
always_inline = false,
278-
enable_early_simplifications = false,
279-
enable_early_optimizations = false,
280-
enable_scalar_optimizations = false,
281-
enable_loop_optimizations = true,
282-
enable_vector_pipeline = false,
283-
remove_ni = false,
284-
cleanup = false,
285-
)
181+
@assert false
286182
end
287183
end
288184

src/llvm/transforms.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,10 +2401,8 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
24012401
# and including 12 (but fixed 13+), Attributor will incorrectly change functions that
24022402
# call code with undef to become unreachable, even when there exist other valid
24032403
# callsites. See: https://godbolt.org/z/9Y3Gv6q5M
2404-
ModulePassManager() do pm
2405-
global_dce!(pm)
2406-
LLVM.run!(pm, mod)
2407-
end
2404+
run!(GlobalDCEPass(), mod)
2405+
24082406
# Prevent dead-arg-elimination of functions which we may require args for in the derivative
24092407
funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg = true)
24102408
if LLVM.version().major <= 15

0 commit comments

Comments
 (0)