Skip to content

Commit e0c0d29

Browse files
committed
Rewrite optimize pipeline for NewPM
1 parent 9a83d04 commit e0c0d29

File tree

2 files changed

+126
-277
lines changed

2 files changed

+126
-277
lines changed

src/compiler/optimize.jl

Lines changed: 102 additions & 256 deletions
Original file line numberDiff line numberDiff line change
@@ -22,277 +22,123 @@ end
2222
EnzymeAttributorPass() = NewPMModulePass("enzyme_attributor", enzyme_attributor_pass!)
2323
ReinsertGCMarkerPass() = NewPMFunctionPass("reinsert_gcmarker", reinsert_gcmarker_pass!)
2424
SafeAtomicToRegularStorePass() = NewPMFunctionPass("safe_atomic_to_regular_store", safe_atomic_to_regular_store!)
25+
Addr13NoAliasPass() = NewPMModulePass("addr13_noalias", addr13NoAlias)
26+
RewriteGenericMemoryPass() = NewPMModulePass("rewrite_generic_memory", rewrite_generic_memory!)
2527

26-
@static if VERSION < v"1.11.0-DEV.428"
27-
else
28-
barrier_noop!(pm) = nothing
29-
end
30-
31-
@static if VERSION < v"1.11-"
32-
function gc_invariant_verifier_tm!(pm::ModulePassManager, tm::LLVM.TargetMachine, cond::Bool)
33-
gc_invariant_verifier!(pm, cond)
34-
end
35-
else
36-
function gc_invariant_verifier_tm!(pm::ModulePassManager, tm::LLVM.TargetMachine, cond::Bool)
37-
function gc_invariant_verifier(mod::LLVM.Module)
38-
@dispose pb = NewPMPassBuilder() begin
39-
add!(pb, NewPMModulePassManager()) do mpm
40-
add!(mpm, NewPMFunctionPassManager()) do fpm
41-
add!(fpm, GCInvariantVerifierPass(; strong = cond))
42-
end
43-
end
44-
run!(pb, mod)
45-
end
46-
return true
28+
function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
29+
@dispose pb = NewPMPassBuilder() begin
30+
registerEnzymeAndPassPipeline!(pb)
31+
register!(pb, Addr13NoAliasPass())
32+
register!(pb, RewriteGenericMemoryPass())
33+
add!(pb, NewPMAAManager()) do aam
34+
add!(aam, ScopedNoAliasAA())
35+
add!(aam, TypeBasedAA())
36+
add!(aam, BasicAA())
4737
end
48-
add!(pm, ModulePass("GCInvariantVerifier", gc_invariant_verifier))
49-
end
50-
end
38+
add!(pb, NewPMModulePassManager()) do mpm
39+
add!(mpm, Addr13NoAliasPass())
5140

52-
@static if VERSION < v"1.11-"
53-
function propagate_julia_addrsp_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
54-
propagate_julia_addrsp!(pm)
55-
end
56-
else
57-
function propagate_julia_addrsp_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
58-
function prop_julia_addr(mod::LLVM.Module)
59-
@dispose pb = NewPMPassBuilder() begin
60-
add!(pb, NewPMModulePassManager()) do mpm
61-
add!(mpm, NewPMFunctionPassManager()) do fpm
62-
add!(fpm, PropagateJuliaAddrspacesPass())
63-
end
64-
end
65-
run!(pb, mod)
41+
add!(mpm, NewPMFunctionPassManager()) do fpm
42+
add!(fpm, PropagateJuliaAddrspacesPass())
43+
add!(fpm, SimplifyCFGPass())
44+
add!(fpm, DCEPass())
45+
end
46+
add!(mpm, CPUFeaturesPass())
47+
add!(mpm, NewPMFunctionPassManager()) do fpm
48+
add!(fpm, SROAPass())
49+
add!(fpm, MemCpyOptPass())
50+
end
51+
add!(mpm, AlwaysInlinerPass())
52+
add!(mpm, NewPMFunctionPassManager()) do fpm
53+
add!(fpm, AllocOptPass())
54+
end
55+
56+
add!(mpm, GlobalOptPass())
57+
add!(mpm, NewPMFunctionPassManager()) do fpm
58+
add!(fpm, GVNPass())
6659
end
67-
return true
68-
end
69-
add!(pm, ModulePass("PropagateJuliaAddrSpace", prop_julia_addr))
70-
end
71-
end
7260

73-
@static if VERSION < v"1.11-"
74-
function alloc_opt_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
75-
alloc_opt!(pm)
76-
end
77-
else
78-
function alloc_opt_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
79-
function alloc_opt(mod::LLVM.Module)
80-
@dispose pb = NewPMPassBuilder() begin
81-
add!(pb, NewPMModulePassManager()) do mpm
82-
add!(mpm, NewPMFunctionPassManager()) do fpm
83-
add!(fpm, AllocOptPass())
84-
end
61+
add!(mpm, RewriteGenericMemoryPass())
62+
63+
add!(mpm, NewPMFunctionPassManager()) do fpm
64+
add!(fpm, InstCombinePass())
65+
add!(fpm, JLInstSimplifyPass())
66+
add!(fpm, SimplifyCFGPass())
67+
add!(fpm, SROAPass())
68+
add!(fpm, InstCombinePass())
69+
add!(fpm, JLInstSimplifyPass())
70+
add!(fpm, JumpThreadingPass())
71+
add!(fpm, CorrelatedValuePropagationPass())
72+
add!(fpm, InstCombinePass())
73+
add!(fpm, JLInstSimplifyPass())
74+
add!(fpm, ReassociatePass())
75+
add!(fpm, EarlyCSEPass())
76+
add!(fpm, AllocOptPass())
77+
add!(fpm, NewPMLoopPassManager(use_memory_ssa=true)) do lpm
78+
add!(lpm, LoopIdiomRecognizePass())
79+
add!(lpm, LoopRotatePass())
80+
add!(lpm, LowerSIMDLoopPass())
81+
add!(lpm, LICMPass())
82+
add!(lpm, JuliaLICMPass())
83+
add!(lpm, SimpleLoopUnswitchPass())
8584
end
86-
run!(pb, mod)
87-
end
88-
return true
89-
end
90-
add!(pm, ModulePass("AllocOpt", alloc_opt))
91-
end
92-
end
9385

94-
@static if VERSION < v"1.11-"
95-
function lower_simdloop_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
96-
lower_simdloop!(pm)
97-
end
98-
else
99-
function lower_simdloop_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
100-
function lower_simdloop(mod::LLVM.Module)
101-
@dispose pb = NewPMPassBuilder() begin
102-
add!(pb, NewPMModulePassManager()) do mpm
103-
add!(mpm, NewPMFunctionPassManager()) do fpm
104-
add!(fpm, NewPMLoopPassManager()) do lpm
105-
add!(lpm, LowerSIMDLoopPass())
106-
end
107-
end
86+
add!(fpm, InstCombinePass())
87+
add!(fpm, JLInstSimplifyPass())
88+
add!(fpm, NewPMLoopPassManager()) do lpm
89+
add!(lpm, IndVarSimplifyPass())
90+
add!(lpm, LoopDeletionPass())
91+
end
92+
add!(fpm, LoopUnrollPass(opt_level=2))
93+
add!(fpm, AllocOptPass())
94+
add!(fpm, SROAPass())
95+
add!(fpm, GVNPass())
96+
97+
# This InstCombine needs to be after GVN
98+
# Otherwise it will generate load chains in GPU code...
99+
add!(fpm, InstCombinePass())
100+
add!(fpm, JLInstSimplifyPass())
101+
add!(fpm, MemCpyOptPass())
102+
add!(fpm, SCCPPass())
103+
add!(fpm, InstCombinePass())
104+
add!(fpm, JLInstSimplifyPass())
105+
add!(fpm, JumpThreadingPass())
106+
add!(fpm, DSEPass())
107+
add!(fpm, AllocOptPass())
108+
add!(fpm, SimplifyCFGPass())
109+
110+
111+
add!(fpm, NewPMLoopPassManager()) do lpm
112+
add!(lpm, LoopIdiomRecognizePass())
113+
add!(lpm, LoopDeletionPass())
108114
end
109-
run!(pb, mod)
115+
add!(fpm, JumpThreadingPass())
116+
add!(fpm, CorrelatedValuePropagationPass())
117+
118+
add!(fpm, ADCEPass())
119+
add!(fpm, InstCombinePass())
120+
add!(fpm, JLInstSimplifyPass())
121+
122+
# GC passes
123+
add!(fpm, GCInvariantVerifierPass(strong=false))
124+
add!(fpm, SimplifyCFGPass())
125+
add!(fpm, InstCombinePass())
126+
add!(fpm, JLInstSimplifyPass())
110127
end
111-
return true
112-
end
113-
# really looppass
114-
add!(pm, ModulePass("LowerSIMDLoop", lower_simdloop))
115-
end
116-
end
117-
118-
function loop_optimizations_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
119-
lower_simdloop_tm!(pm, tm)
120-
licm!(pm)
121-
if LLVM.version() >= v"15"
122-
simple_loop_unswitch_legacy!(pm)
123-
else
124-
loop_unswitch!(pm)
125-
end
126-
end
127128

128-
@static if VERSION < v"1.11-"
129-
function cpu_features_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
130-
@static if isdefined(LLVM.Interop, :cpu_features!)
131-
LLVM.Interop.cpu_features!(pm)
132-
else
133-
@static if isdefined(GPUCompiler, :cpu_features!)
134-
GPUCompiler.cpu_features!(pm)
129+
add!(mpm, GlobalOptPass())
130+
add!(mpm, NewPMFunctionPassManager()) do fpm
131+
add!(fpm, GVNPass())
135132
end
136133
end
137-
end
138-
else
139-
function cpu_features_tm!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachine)
140-
function cpu_features(mod)
141-
@dispose pb = NewPMPassBuilder() begin
142-
add!(pb, NewPMModulePassManager()) do mpm
143-
add!(mpm, CPUFeaturesPass())
144-
end
145-
run!(pb, mod)
146-
end
147-
return true
148-
end
149-
add!(pm, ModulePass("CPUFeatures", cpu_features))
150-
end
151-
end
152134

153-
function jl_inst_simplify!(PM::LLVM.ModulePassManager)
154-
ccall(
155-
(:LLVMAddJLInstSimplifyPass, API.libEnzyme),
156-
Cvoid,
157-
(LLVM.API.LLVMPassManagerRef,),
158-
PM,
159-
)
160-
end
161-
162-
cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm)
163-
164-
function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
165-
addr13NoAlias(mod)
166-
# everying except unroll, slpvec, loop-vec
167-
# then finish Julia GC
168-
ModulePassManager() do pm
169-
add_library_info!(pm, triple(mod))
170-
add_transform_info!(pm, tm)
171-
172-
propagate_julia_addrsp_tm!(pm, tm)
173-
scoped_no_alias_aa!(pm)
174-
type_based_alias_analysis!(pm)
175-
basic_alias_analysis!(pm)
176-
cfgsimplification!(pm)
177-
dce!(pm)
178-
cpu_features_tm!(pm, tm)
179-
scalar_repl_aggregates_ssa!(pm) # SSA variant?
180-
mem_cpy_opt!(pm)
181-
always_inliner!(pm)
182-
alloc_opt_tm!(pm, tm)
183-
LLVM.run!(pm, mod)
184-
end
185-
186-
# Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to
187-
# known functions
188-
ModulePassManager() do pm
189-
190-
add_library_info!(pm, triple(mod))
191-
add_transform_info!(pm, tm)
192-
193-
scoped_no_alias_aa!(pm)
194-
type_based_alias_analysis!(pm)
195-
basic_alias_analysis!(pm)
196-
cpu_features_tm!(pm, tm)
197-
198-
LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Extra
199-
gvn!(pm) # Extra
200-
LLVM.run!(pm, mod)
201-
end
135+
run!(pb, mod, tm)
202136

203-
rewrite_generic_memory!(mod)
204-
205-
ModulePassManager() do pm
206-
add_library_info!(pm, triple(mod))
207-
add_transform_info!(pm, tm)
208-
209-
scoped_no_alias_aa!(pm)
210-
type_based_alias_analysis!(pm)
211-
basic_alias_analysis!(pm)
212-
cpu_features_tm!(pm, tm)
213-
214-
instruction_combining!(pm)
215-
jl_inst_simplify!(pm)
216-
cfgsimplification!(pm)
217-
scalar_repl_aggregates_ssa!(pm) # SSA variant?
218-
instruction_combining!(pm)
219-
jl_inst_simplify!(pm)
220-
jump_threading!(pm)
221-
correlated_value_propagation!(pm)
222-
instruction_combining!(pm)
223-
jl_inst_simplify!(pm)
224-
reassociate!(pm)
225-
early_cse!(pm)
226-
alloc_opt_tm!(pm, tm)
227-
loop_idiom!(pm)
228-
loop_rotate!(pm)
229-
230-
loop_optimizations_tm!(pm, tm)
231-
232-
instruction_combining!(pm)
233-
jl_inst_simplify!(pm)
234-
ind_var_simplify!(pm)
235-
loop_deletion!(pm)
236-
loop_unroll!(pm)
237-
alloc_opt_tm!(pm, tm)
238-
scalar_repl_aggregates_ssa!(pm) # SSA variant?
239-
gvn!(pm)
240-
241-
# This InstCombine needs to be after GVN
242-
# Otherwise it will generate load chains in GPU code...
243-
instruction_combining!(pm)
244-
jl_inst_simplify!(pm)
245-
mem_cpy_opt!(pm)
246-
sccp!(pm)
247-
instruction_combining!(pm)
248-
jl_inst_simplify!(pm)
249-
jump_threading!(pm)
250-
dead_store_elimination!(pm)
251-
alloc_opt_tm!(pm, tm)
252-
cfgsimplification!(pm)
253-
loop_idiom!(pm)
254-
loop_deletion!(pm)
255-
jump_threading!(pm)
256-
correlated_value_propagation!(pm)
257-
# SLP_Vectorizer -- not for Enzyme
258-
259-
LLVM.run!(pm, mod)
260-
261-
aggressive_dce!(pm)
262-
instruction_combining!(pm)
263-
jl_inst_simplify!(pm)
264-
# Loop Vectorize -- not for Enzyme
265-
# InstCombine
266-
267-
# GC passes
268-
barrier_noop!(pm)
269-
gc_invariant_verifier_tm!(pm, tm, false)
270-
271-
# FIXME: Currently crashes printing
272-
cfgsimplification!(pm)
273-
instruction_combining!(pm) # Extra for Enzyme
274-
jl_inst_simplify!(pm)
275-
LLVM.run!(pm, mod)
276-
end
277-
278-
# Globalopt is separated as it can delete functions, which invalidates the Julia hardcoded pointers to
279-
# known functions
280-
ModulePassManager() do pm
281-
add_library_info!(pm, triple(mod))
282-
add_transform_info!(pm, tm)
283-
284-
scoped_no_alias_aa!(pm)
285-
type_based_alias_analysis!(pm)
286-
basic_alias_analysis!(pm)
287-
cpu_features_tm!(pm, tm)
288-
289-
LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Exxtra
290-
gvn!(pm) # Exxtra
291-
LLVM.run!(pm, mod)
137+
# TODO: Turn into passes?
138+
removeDeadArgs!(mod, tm)
139+
detect_writeonly!(mod)
140+
nodecayed_phis!(mod)
292141
end
293-
removeDeadArgs!(mod, tm)
294-
detect_writeonly!(mod)
295-
nodecayed_phis!(mod)
296142
end
297143

298144
function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)

0 commit comments

Comments
 (0)