Skip to content

Commit 7610b13

Browse files
[MLIR] Split ExecutionEngine Initialization out of ctor into an explicit method call (#153524)
Retry landing #153373 ## Major changes from previous attempt - remove the test in CAPI because no existing tests in CAPI deal with sanitizer exemptions - update `mlir/docs/Dialects/GPU.md` to reflect the new behavior: load GPU binary in global ctors, instead of loading them at call site. - skip the test on Aarch64 since we have an issue with initialization there --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 40833ee commit 7610b13

File tree

14 files changed

+310
-13
lines changed

14 files changed

+310
-13
lines changed

mlir/docs/Dialects/GPU.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,25 @@ llvm.func @foo() {
193193
// mlir-translate --mlir-to-llvmir:
194194
@binary_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
195195
@binary_func_kernel_name = private unnamed_addr constant [7 x i8] c"func\00", align 1
196+
@binary_module = internal global ptr null
197+
@llvm.global_ctors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_load, ptr null}]
198+
@llvm.global_dtors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_unload, ptr null}]
199+
define internal void @binary_load() section ".text.startup" {
200+
entry:
201+
%0 = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
202+
store ptr %0, ptr @binary_module
203+
...
204+
}
205+
define internal void @binary_unload() section ".text.startup" {
206+
entry:
207+
%0 = load ptr, ptr @binary_module, align 8
208+
call void @mgpuModuleUnload(ptr %0)
209+
...
210+
}
196211
...
197212
define void @foo() {
198213
...
199-
%module = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
214+
%module = load ptr, ptr @binary_module, align 8
200215
%kernel = call ptr @mgpuModuleGetFunction(ptr %module, ptr @binary_func_kernel_name)
201216
call void @mgpuLaunchKernel(ptr %kernel, ...) ; Launch the kernel
202217
...

mlir/include/mlir-c/ExecutionEngine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
4646
MlirModule op, int optLevel, int numPaths,
4747
const MlirStringRef *sharedLibPaths, bool enableObjectDump);
4848

49+
/// Initialize the ExecutionEngine. Global constructors specified by
50+
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
51+
/// binary compiled from `gpu.module` gets loaded during initialization. Make
52+
/// sure all symbols are resolvable before initialization by calling
53+
/// `mlirExecutionEngineRegisterSymbol` or including shared libraries.
54+
MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);
55+
4956
/// Destroy an ExecutionEngine instance.
5057
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);
5158

mlir/include/mlir/ExecutionEngine/ExecutionEngine.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ class ExecutionEngine {
227227
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
228228
symbolMap);
229229

230+
/// Initialize the ExecutionEngine. Global constructors specified by
231+
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
232+
/// binary compiled from `gpu.module` gets loaded during initialization. Make
233+
/// sure all symbols are resolvable before initialization by calling
234+
/// `registerSymbols` or including shared libraries.
235+
void initialize();
236+
230237
private:
231238
/// Ordering of llvmContext and jit is important for destruction purposes: the
232239
/// jit must be destroyed before the context.
@@ -250,6 +257,8 @@ class ExecutionEngine {
250257
/// Destroy functions in the libraries loaded by the ExecutionEngine that are
251258
/// called when this ExecutionEngine is destructed.
252259
SmallVector<LibraryDestroyFn> destroyFns;
260+
261+
bool isInitialized = false;
253262
};
254263

255264
} // namespace mlir

mlir/lib/Bindings/Python/ExecutionEngineModule.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/ExecutionEngine.h"
10-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1110
#include "mlir/Bindings/Python/Nanobind.h"
11+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
1212

1313
namespace nb = nanobind;
1414
using namespace mlir;
@@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
124124
},
125125
nb::arg("name"), nb::arg("callback"),
126126
"Register `callback` as the runtime symbol `name`.")
127+
.def(
128+
"initialize",
129+
[](PyExecutionEngine &executionEngine) {
130+
mlirExecutionEngineInitialize(executionEngine.get());
131+
},
132+
"Initialize the ExecutionEngine. Global constructors specified by "
133+
"`llvm.mlir.global_ctors` will be run. One common scenario is that "
134+
"kernel binary compiled from `gpu.module` gets loaded during "
135+
"initialization. Make sure all symbols are resolvable before "
136+
"initialization by calling `register_runtime` or including "
137+
"shared libraries.")
127138
.def(
128139
"dump_to_object_file",
129140
[](PyExecutionEngine &executionEngine, const std::string &fileName) {

mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
6868
return wrap(jitOrError->release());
6969
}
7070

71+
extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
72+
unwrap(jit)->initialize();
73+
}
74+
7175
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
7276
delete (unwrap(jit));
7377
}
@@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
106110
void *sym) {
107111
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
108112
llvm::orc::SymbolMap symbolMap;
109-
symbolMap[interner(unwrap(name))] =
110-
{ llvm::orc::ExecutorAddr::fromPtr(sym),
111-
llvm::JITSymbolFlags::Exported };
113+
symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
114+
llvm::JITSymbolFlags::Exported};
112115
return symbolMap;
113116
});
114117
}

mlir/lib/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
106106
}
107107
// Compilation is lazy and it doesn't populate object cache unless requested.
108108
// In case object dump is requested before cache is populated, we need to
109-
// force compilation manually.
109+
// force compilation manually.
110110
if (cache->isEmpty()) {
111111
for (std::string &functionName : functionNames) {
112112
auto result = lookupPacked(functionName);
@@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
400400
return symbolMap;
401401
};
402402
engine->registerSymbols(runtimeSymbolMap);
403-
404-
// Execute the global constructors from the module being processed.
405-
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
406-
// crash for AArch64 see related issue #71963.
407-
if (!engine->jit->getTargetTriple().isAArch64())
408-
cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
409-
410403
return std::move(engine);
411404
}
412405

@@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
442435

443436
Error ExecutionEngine::invokePacked(StringRef name,
444437
MutableArrayRef<void *> args) {
438+
initialize();
445439
auto expectedFPtr = lookupPacked(name);
446440
if (!expectedFPtr)
447441
return expectedFPtr.takeError();
@@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,
451445

452446
return Error::success();
453447
}
448+
449+
void ExecutionEngine::initialize() {
450+
if (isInitialized)
451+
return;
452+
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
453+
// crash for AArch64 see related issue #71963.
454+
if (!jit->getTargetTriple().isAArch64())
455+
cantFail(jit->initialize(jit->getMainJITDylib()));
456+
isInitialized = true;
457+
}

mlir/lib/ExecutionEngine/JitRunner.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
202202

203203
auto engine = std::move(*expectedEngine);
204204

205+
engine->initialize();
206+
205207
auto expectedFPtr = engine->lookupPacked(entryPoint);
206208
if (!expectedFPtr)
207209
return expectedFPtr.takeError();

mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ class ExecutionEngine:
1919
def dump_to_object_file(self, file_name: str) -> None: ...
2020
def raw_lookup(self, func_name: str) -> int: ...
2121
def raw_register_runtime(self, name: str, callback: object) -> None: ...
22+
def init() -> None: ...
2223
@property
2324
def _CAPIPtr(self) -> object: ...

mlir/test/CAPI/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
3030
MLIRCAPIConversion
3131
MLIRCAPIExecutionEngine
3232
MLIRCAPIRegisterEverything
33+
)
34+
_add_capi_test_executable(mlir-capi-global-constructors-test
35+
global_constructors.c
36+
LINK_LIBS PRIVATE
37+
MLIRCAPIConversion
38+
MLIRCAPIExecutionEngine
39+
MLIRCAPIRegisterEverything
3340
)
3441
endif()
3542

mlir/test/CAPI/global_constructors.c

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===- global_constructors.c - Test JIT with the global constructors ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
// UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}}
11+
/* RUN: mlir-capi-global-constructors-test 2>&1 | FileCheck %s
12+
*/
13+
/* REQUIRES: host-supports-jit
14+
*/
15+
16+
#include "mlir-c/Conversion.h"
17+
#include "mlir-c/ExecutionEngine.h"
18+
#include "mlir-c/IR.h"
19+
#include "mlir-c/RegisterEverything.h"
20+
21+
#include <assert.h>
22+
#include <math.h>
23+
#include <stdio.h>
24+
#include <stdlib.h>
25+
#include <string.h>
26+
27+
static void registerAllUpstreamDialects(MlirContext ctx) {
28+
MlirDialectRegistry registry = mlirDialectRegistryCreate();
29+
mlirRegisterAllDialects(registry);
30+
mlirContextAppendDialectRegistry(ctx, registry);
31+
mlirDialectRegistryDestroy(registry);
32+
}
33+
34+
void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
35+
MlirPassManager pm = mlirPassManagerCreate(ctx);
36+
MlirOpPassManager opm = mlirPassManagerGetNestedUnder(
37+
pm, mlirStringRefCreateFromCString("func.func"));
38+
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
39+
mlirOpPassManagerAddOwnedPass(
40+
opm, mlirCreateConversionArithToLLVMConversionPass());
41+
MlirLogicalResult status =
42+
mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
43+
if (mlirLogicalResultIsFailure(status)) {
44+
fprintf(stderr, "Unexpected failure running pass pipeline\n");
45+
exit(2);
46+
}
47+
mlirPassManagerDestroy(pm);
48+
}
49+
50+
// Helper variable to track callback invocations
51+
static int initCnt = 0;
52+
53+
// Callback function that will be called during JIT initialization
54+
static void initCallback(void) { initCnt += 1; }
55+
56+
// CHECK-LABEL: Running test 'testGlobalCtorJitCallback'
57+
void testGlobalCtorJitCallback(void) {
58+
MlirContext ctx = mlirContextCreate();
59+
registerAllUpstreamDialects(ctx);
60+
61+
// Create module with global constructor that calls our callback
62+
MlirModule module = mlirModuleCreateParse(
63+
ctx, mlirStringRefCreateFromCString(
64+
// clang-format off
65+
"module { \n"
66+
" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n"
67+
" llvm.func @ctor() { \n"
68+
" func.call @init_callback() : () -> () \n"
69+
" llvm.return \n"
70+
" } \n"
71+
" func.func private @init_callback() attributes { llvm.emit_c_interface } \n"
72+
"} \n"
73+
// clang-format on
74+
));
75+
76+
lowerModuleToLLVM(ctx, module);
77+
mlirRegisterAllLLVMTranslations(ctx);
78+
79+
// Create execution engine with initialization disabled
80+
MlirExecutionEngine jit = mlirExecutionEngineCreate(
81+
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
82+
/*enableObjectDump=*/false);
83+
84+
if (mlirExecutionEngineIsNull(jit)) {
85+
fprintf(stderr, "Execution engine creation failed");
86+
exit(2);
87+
}
88+
89+
// Register callback symbol before initialization
90+
mlirExecutionEngineRegisterSymbol(
91+
jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"),
92+
(void *)(uintptr_t)initCallback);
93+
94+
mlirExecutionEngineInitialize(jit);
95+
96+
// CHECK: Init count: 1
97+
printf("Init count: %d\n", initCnt);
98+
99+
mlirExecutionEngineDestroy(jit);
100+
mlirModuleDestroy(module);
101+
mlirContextDestroy(ctx);
102+
}
103+
104+
int main(void) {
105+
106+
#define _STRINGIFY(x) #x
107+
#define STRINGIFY(x) _STRINGIFY(x)
108+
#define TEST(test) \
109+
printf("Running test '" STRINGIFY(test) "'\n"); \
110+
test();
111+
TEST(testGlobalCtorJitCallback);
112+
return 0;
113+
}

0 commit comments

Comments
 (0)