Skip to content

Commit 2f93693

Browse files
[MLIR] Split ExecutionEngine Initialization out of ctor into an explicit method call (llvm#153373)
This PR introduces a mechanism to defer JIT engine initialization, enabling registration of required symbols before global constructor execution.

## Problem

Modules containing `gpu.module` generate global constructors (e.g., kernel load/unload) that execute *during* engine creation. This can force premature symbol resolution, causing failures when:

- Symbols are registered via `mlirExecutionEngineRegisterSymbol` *after* creation
- Global constructors exist (even if not directly using unresolved symbols, e.g., an external function declaration)
- GPU modules introduce mandatory binary loading logic

## Usage

```c
// Create engine without initialization
MlirExecutionEngine jit = mlirExecutionEngineCreate(...);
// Register required symbols
mlirExecutionEngineRegisterSymbol(jit, ...);
// Explicitly initialize (runs global constructors)
mlirExecutionEngineInitialize(jit);
```

---------

Co-authored-by: Mehdi Amini <[email protected]>
1 parent 4f6ae2a commit 2f93693

File tree

8 files changed

+141
-12
lines changed

8 files changed

+141
-12
lines changed

mlir/include/mlir-c/ExecutionEngine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
4646
MlirModule op, int optLevel, int numPaths,
4747
const MlirStringRef *sharedLibPaths, bool enableObjectDump);
4848

49+
/// Initialize the ExecutionEngine. Global constructors specified by
50+
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
51+
/// binary compiled from `gpu.module` gets loaded during initialization. Make
52+
/// sure all symbols are resolvable before initialization by calling
53+
/// `mlirExecutionEngineRegisterSymbol` or including shared libraries.
54+
MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);
55+
4956
/// Destroy an ExecutionEngine instance.
5057
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);
5158

mlir/include/mlir/ExecutionEngine/ExecutionEngine.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ class ExecutionEngine {
227227
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
228228
symbolMap);
229229

230+
/// Initialize the ExecutionEngine. Global constructors specified by
231+
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
232+
/// binary compiled from `gpu.module` gets loaded during initialization. Make
233+
/// sure all symbols are resolvable before initialization by calling
234+
/// `registerSymbols` or including shared libraries.
235+
void initialize();
236+
230237
private:
231238
/// Ordering of llvmContext and jit is important for destruction purposes: the
232239
/// jit must be destroyed before the context.
@@ -250,6 +257,8 @@ class ExecutionEngine {
250257
/// Destroy functions in the libraries loaded by the ExecutionEngine that are
251258
/// called when this ExecutionEngine is destructed.
252259
SmallVector<LibraryDestroyFn> destroyFns;
260+
261+
bool isInitialized = false;
253262
};
254263

255264
} // namespace mlir

mlir/lib/Bindings/Python/ExecutionEngineModule.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/ExecutionEngine.h"
10-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1110
#include "mlir/Bindings/Python/Nanobind.h"
11+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1212

1313
namespace nb = nanobind;
1414
using namespace mlir;
@@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
124124
},
125125
nb::arg("name"), nb::arg("callback"),
126126
"Register `callback` as the runtime symbol `name`.")
127+
.def(
128+
"initialize",
129+
[](PyExecutionEngine &executionEngine) {
130+
mlirExecutionEngineInitialize(executionEngine.get());
131+
},
132+
"Initialize the ExecutionEngine. Global constructors specified by "
133+
"`llvm.mlir.global_ctors` will be run. One common scenario is that "
134+
"kernel binary compiled from `gpu.module` gets loaded during "
135+
"initialization. Make sure all symbols are resolvable before "
136+
"initialization by calling `raw_register_runtime` or including "
137+
"shared libraries.")
127138
.def(
128139
"dump_to_object_file",
129140
[](PyExecutionEngine &executionEngine, const std::string &fileName) {

mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
6868
return wrap(jitOrError->release());
6969
}
7070

71+
extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
72+
unwrap(jit)->initialize();
73+
}
74+
7175
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
7276
delete (unwrap(jit));
7377
}
@@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
106110
void *sym) {
107111
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
108112
llvm::orc::SymbolMap symbolMap;
109-
symbolMap[interner(unwrap(name))] =
110-
{ llvm::orc::ExecutorAddr::fromPtr(sym),
111-
llvm::JITSymbolFlags::Exported };
113+
symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
114+
llvm::JITSymbolFlags::Exported};
112115
return symbolMap;
113116
});
114117
}

mlir/lib/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
106106
}
107107
// Compilation is lazy and it doesn't populate object cache unless requested.
108108
// In case object dump is requested before cache is populated, we need to
109-
// force compilation manually.
109+
// force compilation manually.
110110
if (cache->isEmpty()) {
111111
for (std::string &functionName : functionNames) {
112112
auto result = lookupPacked(functionName);
@@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
400400
return symbolMap;
401401
};
402402
engine->registerSymbols(runtimeSymbolMap);
403-
404-
// Execute the global constructors from the module being processed.
405-
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
406-
// crash for AArch64 see related issue #71963.
407-
if (!engine->jit->getTargetTriple().isAArch64())
408-
cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
409-
410403
return std::move(engine);
411404
}
412405

@@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
442435

443436
Error ExecutionEngine::invokePacked(StringRef name,
444437
MutableArrayRef<void *> args) {
438+
initialize();
445439
auto expectedFPtr = lookupPacked(name);
446440
if (!expectedFPtr)
447441
return expectedFPtr.takeError();
@@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,
451445

452446
return Error::success();
453447
}
448+
449+
void ExecutionEngine::initialize() {
450+
if (isInitialized)
451+
return;
452+
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
453+
// crash for AArch64 see related issue #71963.
454+
if (!jit->getTargetTriple().isAArch64())
455+
cantFail(jit->initialize(jit->getMainJITDylib()));
456+
isInitialized = true;
457+
}

mlir/lib/ExecutionEngine/JitRunner.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
202202

203203
auto engine = std::move(*expectedEngine);
204204

205+
engine->initialize();
206+
205207
auto expectedFPtr = engine->lookupPacked(entryPoint);
206208
if (!expectedFPtr)
207209
return expectedFPtr.takeError();

mlir/test/CAPI/execution_engine.c

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,60 @@ void testOmpCreation(void) {
137137
mlirContextDestroy(ctx);
138138
}
139139

140+
// Helper variable to track callback invocations
141+
static int initCnt = 0;
142+
143+
// Callback function that will be called during JIT initialization
144+
static void initCallback(void) { initCnt += 1; }
145+
146+
// CHECK-LABEL: Running test 'testGlobalCtorJitCallback'
147+
void testGlobalCtorJitCallback(void) {
148+
MlirContext ctx = mlirContextCreate();
149+
registerAllUpstreamDialects(ctx);
150+
151+
// Create module with global constructor that calls our callback
152+
MlirModule module = mlirModuleCreateParse(
153+
ctx, mlirStringRefCreateFromCString(
154+
// clang-format off
155+
"module { \n"
156+
" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n"
157+
" llvm.func @ctor() { \n"
158+
" func.call @init_callback() : () -> () \n"
159+
" llvm.return \n"
160+
" } \n"
161+
" func.func private @init_callback() attributes { llvm.emit_c_interface } \n"
162+
"} \n"
163+
// clang-format on
164+
));
165+
166+
lowerModuleToLLVM(ctx, module);
167+
mlirRegisterAllLLVMTranslations(ctx);
168+
169+
// Create the execution engine; global constructors are not run until
// mlirExecutionEngineInitialize is called below.
170+
MlirExecutionEngine jit = mlirExecutionEngineCreate(
171+
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
172+
/*enableObjectDump=*/false);
173+
174+
if (mlirExecutionEngineIsNull(jit)) {
175+
fprintf(stderr, "Execution engine creation failed");
176+
exit(2);
177+
}
178+
179+
// Register callback symbol before initialization
180+
mlirExecutionEngineRegisterSymbol(
181+
jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"),
182+
(void *)(uintptr_t)initCallback);
183+
184+
mlirExecutionEngineInitialize(jit);
185+
186+
// CHECK: Init count: 1
187+
printf("Init count: %d\n", initCnt);
188+
189+
mlirExecutionEngineDestroy(jit);
190+
mlirModuleDestroy(module);
191+
mlirContextDestroy(ctx);
192+
}
193+
140194
int main(void) {
141195

142196
#define _STRINGIFY(x) #x
@@ -147,5 +201,6 @@ int main(void) {
147201

148202
TEST(testSimpleExecution);
149203
TEST(testOmpCreation);
204+
TEST(testGlobalCtorJitCallback);
150205
return 0;
151206
}

mlir/unittests/ExecutionEngine/Invoke.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,4 +322,42 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
322322
ASSERT_EQ(elt, coefficient * count++);
323323
}
324324

325+
static int initCnt = 0;
326+
// A helper function that will be called during the JIT's initialization.
327+
static void initCallback() { initCnt += 1; }
328+
329+
TEST(GlobalCtorJit, MAYBE_JITCallback) {
330+
std::string moduleStr = R"mlir(
331+
llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
332+
llvm.func @ctor() {
333+
func.call @init_callback() : () -> ()
334+
llvm.return
335+
}
336+
func.func private @init_callback() attributes { llvm.emit_c_interface }
337+
)mlir";
338+
339+
DialectRegistry registry;
340+
registerAllDialects(registry);
341+
registerBuiltinDialectTranslation(registry);
342+
registerLLVMDialectTranslation(registry);
343+
MLIRContext context(registry);
344+
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
345+
ASSERT_TRUE(!!module);
346+
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
347+
ExecutionEngineOptions jitOptions;
348+
auto jitOrError = ExecutionEngine::create(*module, jitOptions);
349+
ASSERT_TRUE(!!jitOrError);
350+
auto jit = std::move(jitOrError.get());
351+
// Define any extra symbols so they're available at initialization.
352+
jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
353+
llvm::orc::SymbolMap symbolMap;
354+
symbolMap[interner("_mlir_ciface_init_callback")] = {
355+
llvm::orc::ExecutorAddr::fromPtr(initCallback),
356+
llvm::JITSymbolFlags::Exported};
357+
return symbolMap;
358+
});
359+
jit->initialize();
360+
ASSERT_EQ(initCnt, 1);
361+
}
362+
325363
#endif // _WIN32

0 commit comments

Comments
 (0)