Skip to content

Commit 53f3cfd

Browse files
gbaraldimaleadt
authored andcommitted
Implement callback invoking on the NewPM API
1 parent 90917c8 commit 53f3cfd

File tree

6 files changed

+300
-4
lines changed

6 files changed

+300
-4
lines changed

deps/LLVMExtra/include/LLVMExtra.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ typedef struct LLVMOpaquePassBuilderExtensions *LLVMPassBuilderExtensionsRef;
232232
LLVMPassBuilderExtensionsRef LLVMCreatePassBuilderExtensions(void);
233233
void LLVMDisposePassBuilderExtensions(LLVMPassBuilderExtensionsRef Extensions);
234234
void LLVMPassBuilderExtensionsPushRegistrationCallbacks(LLVMPassBuilderExtensionsRef Options,
235-
void (*RegistrationCallback)(void *));
235+
void (*RegistrationCallback)(void *));
236236
typedef LLVMBool (*LLVMJuliaModulePassCallback)(LLVMModuleRef M, void *Thunk);
237237
typedef LLVMBool (*LLVMJuliaFunctionPassCallback)(LLVMValueRef F, void *Thunk);
238238
void LLVMPassBuilderExtensionsRegisterModulePass(LLVMPassBuilderExtensionsRef Options,

deps/LLVMExtra/lib/NewPM.cpp

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <llvm/Passes/PassBuilder.h>
77
#include <llvm/Passes/StandardInstrumentations.h>
88
#include <llvm/Support/CBindingWrapping.h>
9+
#include <optional>
910

1011
using namespace llvm;
1112

@@ -143,6 +144,167 @@ void LLVMPassBuilderExtensionsSetAAPipeline(LLVMPassBuilderExtensionsRef Extensi
143144
}
144145
#endif
145146

147+
static bool checkParametrizedPassName(StringRef Name, StringRef PassName) {
148+
if (!Name.consume_front(PassName))
149+
return false;
150+
// normal pass name w/o parameters == default parameters
151+
if (Name.empty())
152+
return true;
153+
#if LLVM_VERSION_MAJOR >= 16
154+
return Name.starts_with("<") && Name.ends_with(">");
155+
#else
156+
return Name.startswith("<") && Name.endswith(">");
157+
#endif
158+
}
159+
160+
static std::optional<OptimizationLevel> parseOptLevel(StringRef S) {
161+
return StringSwitch<std::optional<OptimizationLevel>>(S)
162+
.Case("O0", OptimizationLevel::O0)
163+
.Case("O1", OptimizationLevel::O1)
164+
.Case("O2", OptimizationLevel::O2)
165+
.Case("O3", OptimizationLevel::O3)
166+
.Case("Os", OptimizationLevel::Os)
167+
.Case("Oz", OptimizationLevel::Oz)
168+
.Default(std::nullopt);
169+
}
170+
171+
static Expected<OptimizationLevel> parseOptLevelParam(StringRef S) {
172+
std::optional<OptimizationLevel> OptLevel = parseOptLevel(S);
173+
if (OptLevel)
174+
return *OptLevel;
175+
return make_error<StringError>(
176+
formatv("invalid optimization level '{}'", S).str(),
177+
inconvertibleErrorCode());
178+
}
179+
180+
template <typename ParametersParseCallableT>
181+
static auto parsePassParameters(ParametersParseCallableT &&Parser,
182+
StringRef Name, StringRef PassName)
183+
-> decltype(Parser(StringRef{})) {
184+
using ParametersT = typename decltype(Parser(StringRef{}))::value_type;
185+
186+
StringRef Params = Name;
187+
if (!Params.consume_front(PassName)) {
188+
llvm_unreachable(
189+
"unable to strip pass name from parametrized pass specification");
190+
}
191+
if (!Params.empty() &&
192+
(!Params.consume_front("<") || !Params.consume_back(">"))) {
193+
llvm_unreachable("invalid format for parametrized pass name");
194+
}
195+
196+
Expected<ParametersT> Result = Parser(Params);
197+
assert((Result || Result.template errorIsA<StringError>()) &&
198+
"Pass parameter parser can only return StringErrors.");
199+
return Result;
200+
}
201+
202+
203+
// Register target specific parsing callbacks
204+
static void registerCallbackParsing(PassBuilder &PB) {
205+
PB.registerPipelineParsingCallback(
206+
[&](StringRef Name, ModulePassManager &PM,
207+
ArrayRef<PassBuilder::PipelineElement>) {
208+
#define MODULE_CALLBACK(NAME, INVOKE) \
209+
if (checkParametrizedPassName(Name, NAME)) { \
210+
auto L = parsePassParameters(parseOptLevelParam, Name, NAME); \
211+
if (!L) { \
212+
errs() << NAME ": " << toString(L.takeError()) << '\n'; \
213+
return false; \
214+
} \
215+
PB.INVOKE(PM, L.get()); \
216+
return true; \
217+
}
218+
#include "callbacks.inc"
219+
return false;
220+
});
221+
222+
// Module-level callbacks with LTO phase (use Phase::None for string API)
223+
PB.registerPipelineParsingCallback(
224+
[&](StringRef Name, ModulePassManager &PM,
225+
ArrayRef<PassBuilder::PipelineElement>) {
226+
#if LLVM_VERSION_MAJOR > 20
227+
#define MODULE_LTO_CALLBACK(NAME, INVOKE) \
228+
if (checkParametrizedPassName(Name, NAME)) { \
229+
auto L = parsePassParameters(parseOptLevelParam, Name, NAME); \
230+
if (!L) { \
231+
errs() << NAME ": " << toString(L.takeError()) << '\n'; \
232+
return false; \
233+
} \
234+
PB.INVOKE(PM, L.get(), ThinOrFullLTOPhase::None); \
235+
return true; \
236+
}
237+
#include "callbacks.inc"
238+
#else
239+
#define MODULE_LTO_CALLBACK(NAME, INVOKE) \
240+
if (checkParametrizedPassName(Name, NAME)) { \
241+
auto L = parsePassParameters(parseOptLevelParam, Name, NAME); \
242+
if (!L) { \
243+
errs() << NAME ": " << toString(L.takeError()) << '\n'; \
244+
return false; \
245+
} \
246+
PB.INVOKE(PM, L.get()); \
247+
return true; \
248+
}
249+
#include "callbacks.inc"
250+
#endif
251+
return false;
252+
});
253+
254+
// Function-level callbacks
255+
PB.registerPipelineParsingCallback(
256+
[&](StringRef Name, FunctionPassManager &PM,
257+
ArrayRef<PassBuilder::PipelineElement>) {
258+
#define FUNCTION_CALLBACK(NAME, INVOKE) \
259+
if (checkParametrizedPassName(Name, NAME)) { \
260+
auto L = parsePassParameters(parseOptLevelParam, Name, NAME); \
261+
if (!L) { \
262+
errs() << NAME ": " << toString(L.takeError()) << '\n'; \
263+
return false; \
264+
} \
265+
PB.INVOKE(PM, L.get()); \
266+
return true; \
267+
}
268+
#include "callbacks.inc"
269+
return false;
270+
});
271+
272+
// CGSCC-level callbacks
273+
PB.registerPipelineParsingCallback(
274+
[&](StringRef Name, CGSCCPassManager &PM,
275+
ArrayRef<PassBuilder::PipelineElement>) {
276+
#define CGSCC_CALLBACK(NAME, INVOKE) \
277+
if (checkParametrizedPassName(Name, NAME)) { \
278+
auto L = parsePassParameters(parseOptLevelParam, Name, NAME); \
279+
if (!L) { \
280+
errs() << NAME ": " << toString(L.takeError()) << '\n'; \
281+
return false; \
282+
} \
283+
PB.INVOKE(PM, L.get()); \
284+
return true; \
285+
}
286+
#include "callbacks.inc"
287+
return false;
288+
});
289+
290+
// Loop-level callbacks
291+
PB.registerPipelineParsingCallback(
292+
[&](StringRef Name, LoopPassManager &PM,
293+
ArrayRef<PassBuilder::PipelineElement>) {
294+
#define LOOP_CALLBACK(NAME, INVOKE) \
295+
if (checkParametrizedPassName(Name, NAME)) { \
296+
auto L = parsePassParameters(parseOptLevelParam, Name, NAME); \
297+
if (!L) { \
298+
errs() << NAME ": " << toString(L.takeError()) << '\n'; \
299+
return false; \
300+
} \
301+
PB.INVOKE(PM, L.get()); \
302+
return true; \
303+
}
304+
#include "callbacks.inc"
305+
return false;
306+
});
307+
}
146308

147309
// Vendored API entrypoint
148310

@@ -165,7 +327,7 @@ static LLVMErrorRef runJuliaPasses(Module *Mod, Function *Fun, const char *Passe
165327
PB.registerPipelineParsingCallback(Callback);
166328
for (auto &Callback : PassExts->FunctionPipelineParsingCallbacks)
167329
PB.registerPipelineParsingCallback(Callback);
168-
330+
registerCallbackParsing(PB); // Parsing for target-specific callbacks
169331
LoopAnalysisManager LAM;
170332
FunctionAnalysisManager FAM;
171333
CGSCCAnalysisManager CGAM;

deps/LLVMExtra/lib/callbacks.inc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifdef MODULE_CALLBACK
2+
MODULE_CALLBACK("pipeline-start-callbacks", invokePipelineStartEPCallbacks)
3+
#endif
4+
#undef MODULE_CALLBACK
5+
6+
// There are some full lto specific ones that are ignored here for now
7+
#ifdef MODULE_LTO_CALLBACK
8+
MODULE_LTO_CALLBACK("pipeline-early-simplification-callbacks", invokePipelineEarlySimplificationEPCallbacks)
9+
MODULE_LTO_CALLBACK("optimizer-early-callbacks", invokeOptimizerEarlyEPCallbacks)
10+
MODULE_LTO_CALLBACK("optimizer-last-callbacks", invokeOptimizerLastEPCallbacks)
11+
#endif
12+
#undef MODULE_LTO_CALLBACK
13+
14+
#ifdef FUNCTION_CALLBACK
15+
FUNCTION_CALLBACK("peephole-callbacks", invokePeepholeEPCallbacks)
16+
FUNCTION_CALLBACK("scalar-optimizer-late-callbacks", invokeScalarOptimizerLateEPCallbacks)
17+
FUNCTION_CALLBACK("vectorizer-start-callbacks", invokeVectorizerStartEPCallbacks)
18+
#if LLVM_VERSION_MAJOR >= 21
19+
FUNCTION_CALLBACK("vectorizer-end-callbacks", invokeVectorizerEndEPCallbacks)
20+
#endif
21+
#endif
22+
#undef FUNCTION_CALLBACK
23+
24+
#ifdef CGSCC_CALLBACK
25+
CGSCC_CALLBACK("cgscc-optimizer-late-callbacks", invokeCGSCCOptimizerLateEPCallbacks)
26+
#endif
27+
#undef CGSCC_CALLBACK
28+
29+
#ifdef LOOP_CALLBACK
30+
LOOP_CALLBACK("late-loop-optimizations-callbacks", invokeLateLoopOptimizationsEPCallbacks)
31+
LOOP_CALLBACK("loop-optimizer-end-callbacks", invokeLoopOptimizerEndEPCallbacks)
32+
#endif
33+
#undef LOOP_CALLBACK

deps/build_local.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ LLVM_DIR = joinpath(LLVM.artifact_dir, "lib", "cmake", "llvm")
4747

4848
# build and install
4949
@info "Building" source_dir scratch_dir build_dir LLVM_DIR
50-
cmake() do cmake_path
50+
cmake(adjust_LIBPATH=false) do cmake_path
5151
config_opts = `-DLLVM_ROOT=$(LLVM_DIR) -DCMAKE_INSTALL_PREFIX=$(scratch_dir)`
5252
if Sys.iswindows()
5353
# prevent picking up MSVC

src/newpm.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,29 @@ function InternalizePass(; preserved_gvs::Vector=String[], kwargs...)
514514

515515
"internalize" * kwargs_to_params(kwargs)
516516
end
517+
# Module callbacks (not part of general pass sweep)
518+
export PipelineStartCallbacks, PipelineEarlySimplificationCallbacks,
519+
OptimizerEarlyCallbacks, OptimizerLastCallbacks
520+
function PipelineStartCallbacks(; opt_level=0)
521+
opts = Dict{Symbol,Any}()
522+
opts[Symbol("O$opt_level")] = true
523+
"pipeline-start-callbacks" * kwargs_to_params(opts)
524+
end
525+
function PipelineEarlySimplificationCallbacks(; opt_level=0)
526+
opts = Dict{Symbol,Any}()
527+
opts[Symbol("O$opt_level")] = true
528+
"pipeline-early-simplification-callbacks" * kwargs_to_params(opts)
529+
end
530+
function OptimizerEarlyCallbacks(; opt_level=0)
531+
opts = Dict{Symbol,Any}()
532+
opts[Symbol("O$opt_level")] = true
533+
"optimizer-early-callbacks" * kwargs_to_params(opts)
534+
end
535+
function OptimizerLastCallbacks(; opt_level=0)
536+
opts = Dict{Symbol,Any}()
537+
opts[Symbol("O$opt_level")] = true
538+
"optimizer-last-callbacks" * kwargs_to_params(opts)
539+
end
517540

518541
# CGSCC passes
519542

@@ -526,6 +549,14 @@ end
526549
@cgscc_pass "inline" InlinerPass
527550
@cgscc_pass "coro-split" CoroSplitPass
528551

552+
#CGSCC callbacks (not part of general pass sweep)
553+
export CGSCCOptimizerLateCallbacks
554+
function CGSCCOptimizerLateCallbacks(; opt_level=0)
555+
opts = Dict{Symbol,Any}()
556+
opts[Symbol("O$opt_level")] = true
557+
"cgscc-optimizer-late-callbacks" * kwargs_to_params(opts)
558+
end
559+
529560
# function passes
530561

531562
@function_pass "aa-eval" AAEvaluator
@@ -709,6 +740,31 @@ end
709740
@function_pass "gvn" GVNPass
710741
@function_pass "print<stack-lifetime>" StackLifetimePrinterPass
711742

743+
# Function pass callbacks (not part of general pass sweep)
744+
export PeepholeCallbacks, ScalarOptimizerLateCallbacks, VectorizerStartCallbacks
745+
function PeepholeCallbacks(; opt_level=0)
746+
opts = Dict{Symbol,Any}()
747+
opts[Symbol("O$opt_level")] = true
748+
"peephole-callbacks" * kwargs_to_params(opts)
749+
end
750+
function ScalarOptimizerLateCallbacks(; opt_level=0)
751+
opts = Dict{Symbol,Any}()
752+
opts[Symbol("O$opt_level")] = true
753+
"scalar-optimizer-late-callbacks" * kwargs_to_params(opts)
754+
end
755+
function VectorizerStartCallbacks(; opt_level=0)
756+
opts = Dict{Symbol,Any}()
757+
opts[Symbol("O$opt_level")] = true
758+
"vectorizer-start-callbacks" * kwargs_to_params(opts)
759+
end
760+
@static if version() >= v"21"
761+
export VectorizerEndCallbacks
762+
function VectorizerEndCallbacks(; opt_level=0)
763+
opts = Dict{Symbol,Any}()
764+
opts[Symbol("O$opt_level")] = true
765+
"vectorizer-end-callbacks" * kwargs_to_params(opts)
766+
end
767+
end
712768
# loop nest passes
713769

714770
@loop_pass "loop-flatten" LoopFlattenPass
@@ -746,6 +802,19 @@ end
746802
@loop_pass "licm" LICMPass
747803
@loop_pass "lnicm" LNICMPass
748804

805+
# Loop Callbacks (not part of general pass sweep)
806+
export LateLoopOptimizationsCallbacks, LoopOptimizerEndCallbacks
807+
function LateLoopOptimizationsCallbacks(; opt_level=0)
808+
opts = Dict{Symbol,Any}()
809+
opts[Symbol("O$opt_level")] = true
810+
"late-loop-optimizations-callbacks" * kwargs_to_params(opts)
811+
end
812+
function LoopOptimizerEndCallbacks(; opt_level=0)
813+
opts = Dict{Symbol,Any}()
814+
opts[Symbol("O$opt_level")] = true
815+
"loop-optimizer-end-callbacks" * kwargs_to_params(opts)
816+
end
817+
749818

750819
## alias analyses
751820

test/newpm.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,12 @@ end
195195
end
196196

197197
@testset "loop" begin
198-
test_passes("loop", LLVM.loop_passes)
198+
# skip opt-level callback pseudo-passes, they require parameters and are provided as functions
199+
skip_loop = [
200+
"late-loop-optimizations-callbacks",
201+
"loop-optimizer-end-callbacks",
202+
]
203+
test_passes("loop", LLVM.loop_passes, skip_loop)
199204
end
200205
end
201206

@@ -409,4 +414,31 @@ end
409414
end
410415
end
411416

417+
@testset "callbacks" begin
418+
# just check that the callbacks can be registered and run without errors
419+
@dispose ctx=Context() begin
420+
# module callbacks
421+
@dispose pb=NewPMPassBuilder() mod=test_module() begin
422+
@test run!("pipeline-start-callbacks<O0>", mod) === nothing
423+
end
424+
@dispose pb=NewPMPassBuilder() mod=test_module() begin
425+
@test run!(PipelineStartCallbacks(opt_level=0), mod) === nothing
426+
end
427+
# CGSCC callback
428+
@dispose pb=NewPMPassBuilder() mod=test_module() begin
429+
@test run!("cgscc-optimizer-late-callbacks<O0>", mod) === nothing
430+
end
431+
@dispose pb=NewPMPassBuilder() mod=test_module() begin
432+
@test run!(CGSCCOptimizerLateCallbacks(opt_level=0), mod) === nothing
433+
end
434+
# function callbacks
435+
@dispose pb=NewPMPassBuilder() mod=test_module() begin
436+
@test run!("peephole-callbacks<O0>", mod) === nothing
437+
end
438+
@dispose pb=NewPMPassBuilder() mod=test_module() begin
439+
@test run!(PeepholeCallbacks(opt_level=0), mod) === nothing
440+
end
441+
end
442+
end
443+
412444
end

0 commit comments

Comments
 (0)