diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 99cddc5c2598d..1a58d68533f24 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, bool enableObjectDump); +/// Initialize the ExecutionEngine. Global constructors specified by +/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel +/// binary compiled from `gpu.module` gets loaded during initialization. Make +/// sure all symbols are resolvable before initialization by calling +/// `mlirExecutionEngineRegisterSymbol` or including shared libraries. +MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit); + /// Destroy an ExecutionEngine instance. MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h index 96ccebcd5685e..5bd71d68d253a 100644 --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -227,6 +227,13 @@ class ExecutionEngine { llvm::function_ref symbolMap); + /// Initialize the ExecutionEngine. Global constructors specified by + /// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel + /// binary compiled from `gpu.module` gets loaded during initialization. Make + /// sure all symbols are resolvable before initialization by calling + /// `registerSymbols` or including shared libraries. + void initialize(); + private: /// Ordering of llvmContext and jit is important for destruction purposes: the /// jit must be destroyed before the context. @@ -250,6 +257,8 @@ class ExecutionEngine { /// Destroy functions in the libraries loaded by the ExecutionEngine that are /// called when this ExecutionEngine is destructed. SmallVector destroyFns; + + bool isInitialized = false; }; } // namespace mlir diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 81dada3553622..4f7a4a628e246 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; @@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) { }, nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") + .def( + "initialize", + [](PyExecutionEngine &executionEngine) { + mlirExecutionEngineInitialize(executionEngine.get()); + }, + "Initialize the ExecutionEngine. Global constructors specified by " + "`llvm.mlir.global_ctors` will be run. One common scenario is that " + "kernel binary compiled from `gpu.module` gets loaded during " + "initialization. Make sure all symbols are resolvable before " + "initialization by calling `raw_register_runtime` or including " + "shared libraries.") .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 306cebd236be9..2dbb993b1640f 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, return wrap(jitOrError->release()); } +extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) { + unwrap(jit)->initialize(); +} + extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { delete (unwrap(jit)); } @@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = - { llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported }; + symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported}; return symbolMap; }); } diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index f704fbfbe8fff..52162a43aeae3 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) { } // Compilation is lazy and it doesn't populate object cache unless requested. // In case object dump is requested before cache is populated, we need to - // force compilation manually. + // force compilation manually. if (cache->isEmpty()) { for (std::string &functionName : functionNames) { auto result = lookupPacked(functionName); @@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, return symbolMap; }; engine->registerSymbols(runtimeSymbolMap); - - // Execute the global constructors from the module being processed. - // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a - // crash for AArch64 see related issue #71963. - if (!engine->jit->getTargetTriple().isAArch64()) - cantFail(engine->jit->initialize(engine->jit->getMainJITDylib())); - return std::move(engine); } @@ -442,6 +435,7 @@ Expected ExecutionEngine::lookup(StringRef name) const { Error ExecutionEngine::invokePacked(StringRef name, MutableArrayRef args) { + initialize(); auto expectedFPtr = lookupPacked(name); if (!expectedFPtr) return expectedFPtr.takeError(); @@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name, return Error::success(); } + +void ExecutionEngine::initialize() { + if (isInitialized) + return; + // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a + // crash for AArch64 see related issue #71963. + if (!jit->getTargetTriple().isAArch64()) + cantFail(jit->initialize(jit->getMainJITDylib())); + isInitialized = true; +} diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index 2107df37d1997..0ada4cc96570a 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint, auto engine = std::move(*expectedEngine); + engine->initialize(); + auto expectedFPtr = engine->lookupPacked(entryPoint); if (!expectedFPtr) return expectedFPtr.takeError(); diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c index 4751288c3ee4b..900588aeea2da 100644 --- a/mlir/test/CAPI/execution_engine.c +++ b/mlir/test/CAPI/execution_engine.c @@ -137,6 +137,60 @@ void testOmpCreation(void) { mlirContextDestroy(ctx); } +// Helper variable to track callback invocations +static int initCnt = 0; + +// Callback function that will be called during JIT initialization +static void initCallback(void) { initCnt += 1; } + +// CHECK-LABEL: Running test 'testGlobalCtorJitCallback' +void testGlobalCtorJitCallback(void) { + MlirContext ctx = mlirContextCreate(); + registerAllUpstreamDialects(ctx); + + // Create module with global constructor that calls our callback + MlirModule module = mlirModuleCreateParse( + ctx, mlirStringRefCreateFromCString( + // clang-format off +"module { \n" +" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n" +" llvm.func @ctor() { \n" +" func.call @init_callback() : () -> () \n" +" llvm.return \n" +" } \n" +" func.func private @init_callback() attributes { llvm.emit_c_interface } \n" +"} \n" + // clang-format on + )); + + lowerModuleToLLVM(ctx, module); + mlirRegisterAllLLVMTranslations(ctx); + + // Create execution engine with initialization disabled + MlirExecutionEngine jit = mlirExecutionEngineCreate( + module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, + /*enableObjectDump=*/false); + + if (mlirExecutionEngineIsNull(jit)) { + fprintf(stderr, "Execution engine creation failed"); + exit(2); + } + + // Register callback symbol before initialization + mlirExecutionEngineRegisterSymbol( + jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"), + (void *)(uintptr_t)initCallback); + + mlirExecutionEngineInitialize(jit); + + // CHECK: Init count: 1 + printf("Init count: %d\n", initCnt); + + mlirExecutionEngineDestroy(jit); + mlirModuleDestroy(module); + mlirContextDestroy(ctx); +} + int main(void) { #define _STRINGIFY(x) #x @@ -147,5 +201,6 @@ int main(void) { TEST(testSimpleExecution); TEST(testOmpCreation); + TEST(testGlobalCtorJitCallback); return 0; } diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp index 887db227cfc4b..e2f2fe0215b91 100644 --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -316,4 +316,42 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) { ASSERT_EQ(elt, coefficient * count++); } +static int initCnt = 0; +// A helper function that will be called during the JIT's initialization. +static void initCallback() { initCnt += 1; } + +TEST(GlobalCtorJit, MAYBE_JITCallback) { + std::string moduleStr = R"mlir( + llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] + llvm.func @ctor() { + func.call @init_callback() : () -> () + llvm.return + } + func.func private @init_callback() attributes { llvm.emit_c_interface } + )mlir"; + + DialectRegistry registry; + registerAllDialects(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); + auto module = parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); + ExecutionEngineOptions jitOptions; + auto jitOrError = ExecutionEngine::create(*module, jitOptions); + ASSERT_TRUE(!!jitOrError); + auto jit = std::move(jitOrError.get()); + // Define any extra symbols so they're available at initialization. + jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { + llvm::orc::SymbolMap symbolMap; + symbolMap[interner("_mlir_ciface_init_callback")] = { + llvm::orc::ExecutorAddr::fromPtr(initCallback), + llvm::JITSymbolFlags::Exported}; + return symbolMap; + }); + jit->initialize(); + ASSERT_EQ(initCnt, 1); +} + #endif // _WIN32