diff --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md index 94b053daa1615..8d4d2ca3e5743 100644 --- a/mlir/docs/Dialects/GPU.md +++ b/mlir/docs/Dialects/GPU.md @@ -193,10 +193,25 @@ llvm.func @foo() { // mlir-translate --mlir-to-llvmir: @binary_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8 @binary_func_kernel_name = private unnamed_addr constant [7 x i8] c"func\00", align 1 +@binary_module = internal global ptr null +@llvm.global_ctors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_load, ptr null}] +@llvm.global_dtors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_unload, ptr null}] +define internal void @binary_load() section ".text.startup" { +entry: + %0 = call ptr @mgpuModuleLoad(ptr @binary_bin_cst) + store ptr %0, ptr @binary_module + ... +} +define internal void @binary_unload() section ".text.startup" { +entry: + %0 = load ptr, ptr @binary_module, align 8 + call void @mgpuModuleUnload(ptr %0) + ... +} ... define void @foo() { ... - %module = call ptr @mgpuModuleLoad(ptr @binary_bin_cst) + %module = load ptr, ptr @binary_module, align 8 %kernel = call ptr @mgpuModuleGetFunction(ptr %module, ptr @binary_func_kernel_name) call void @mgpuLaunchKernel(ptr %kernel, ...) ; Launch the kernel ... diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 99cddc5c2598d..1a58d68533f24 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, bool enableObjectDump); +/// Initialize the ExecutionEngine. Global constructors specified by +/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel +/// binary compiled from `gpu.module` gets loaded during initialization. 
Make +/// sure all symbols are resolvable before initialization by calling +/// `mlirExecutionEngineRegisterSymbol` or including shared libraries. +MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit); + /// Destroy an ExecutionEngine instance. MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h index 96ccebcd5685e..5bd71d68d253a 100644 --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -227,6 +227,13 @@ class ExecutionEngine { llvm::function_ref symbolMap); + /// Initialize the ExecutionEngine. Global constructors specified by + /// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel + /// binary compiled from `gpu.module` gets loaded during initialization. Make + /// sure all symbols are resolvable before initialization by calling + /// `registerSymbols` or including shared libraries. + void initialize(); + private: /// Ordering of llvmContext and jit is important for destruction purposes: the /// jit must be destroyed before the context. @@ -250,6 +257,8 @@ class ExecutionEngine { /// Destroy functions in the libraries loaded by the ExecutionEngine that are /// called when this ExecutionEngine is destructed. 
SmallVector destroyFns; + + bool isInitialized = false; }; } // namespace mlir diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 81dada3553622..4885d62c56e6e 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; @@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) { }, nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") + .def( + "initialize", + [](PyExecutionEngine &executionEngine) { + mlirExecutionEngineInitialize(executionEngine.get()); + }, + "Initialize the ExecutionEngine. Global constructors specified by " + "`llvm.mlir.global_ctors` will be run. One common scenario is that " + "kernel binary compiled from `gpu.module` gets loaded during " + "initialization. 
Make sure all symbols are resolvable before " + "initialization by calling `register_runtime` or including " + "shared libraries.") .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 306cebd236be9..2dbb993b1640f 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, return wrap(jitOrError->release()); } +extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) { + unwrap(jit)->initialize(); +} + extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { delete (unwrap(jit)); } @@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = - { llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported }; + symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported}; return symbolMap; }); } diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index f704fbfbe8fff..52162a43aeae3 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) { } // Compilation is lazy and it doesn't populate object cache unless requested. // In case object dump is requested before cache is populated, we need to - // force compilation manually. + // force compilation manually. 
if (cache->isEmpty()) { for (std::string &functionName : functionNames) { auto result = lookupPacked(functionName); @@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, return symbolMap; }; engine->registerSymbols(runtimeSymbolMap); - - // Execute the global constructors from the module being processed. - // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a - // crash for AArch64 see related issue #71963. - if (!engine->jit->getTargetTriple().isAArch64()) - cantFail(engine->jit->initialize(engine->jit->getMainJITDylib())); - return std::move(engine); } @@ -442,6 +435,7 @@ Expected ExecutionEngine::lookup(StringRef name) const { Error ExecutionEngine::invokePacked(StringRef name, MutableArrayRef args) { + initialize(); auto expectedFPtr = lookupPacked(name); if (!expectedFPtr) return expectedFPtr.takeError(); @@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name, return Error::success(); } + +void ExecutionEngine::initialize() { + if (isInitialized) + return; + // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a + // crash for AArch64 see related issue #71963. 
+ if (!jit->getTargetTriple().isAArch64()) + cantFail(jit->initialize(jit->getMainJITDylib())); + isInitialized = true; +} diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index 2107df37d1997..0ada4cc96570a 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint, auto engine = std::move(*expectedEngine); + engine->initialize(); + auto expectedFPtr = engine->lookupPacked(entryPoint); if (!expectedFPtr) return expectedFPtr.takeError(); diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi index 58d453d2b2d37..4b82c78489295 100644 --- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -19,5 +19,6 @@ class ExecutionEngine: def dump_to_object_file(self, file_name: str) -> None: ... def raw_lookup(self, func_name: str) -> int: ... def raw_register_runtime(self, name: str, callback: object) -> None: ... + def initialize(self) -> None: ... @property def _CAPIPtr(self) -> object: ... 
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt index a7f9eb9b4efe8..d45142510a496 100644 --- a/mlir/test/CAPI/CMakeLists.txt +++ b/mlir/test/CAPI/CMakeLists.txt @@ -30,6 +30,13 @@ if(MLIR_ENABLE_EXECUTION_ENGINE) MLIRCAPIConversion MLIRCAPIExecutionEngine MLIRCAPIRegisterEverything +) + _add_capi_test_executable(mlir-capi-global-constructors-test + global_constructors.c + LINK_LIBS PRIVATE + MLIRCAPIConversion + MLIRCAPIExecutionEngine + MLIRCAPIRegisterEverything ) endif() diff --git a/mlir/test/CAPI/global_constructors.c b/mlir/test/CAPI/global_constructors.c new file mode 100644 index 0000000000000..bd2fe1416f0df --- /dev/null +++ b/mlir/test/CAPI/global_constructors.c @@ -0,0 +1,113 @@ +//===- global_constructors.c - Test JIT with the global constructors ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}} +/* RUN: mlir-capi-global-constructors-test 2>&1 | FileCheck %s + */ +/* REQUIRES: host-supports-jit + */ + +#include "mlir-c/Conversion.h" +#include "mlir-c/ExecutionEngine.h" +#include "mlir-c/IR.h" +#include "mlir-c/RegisterEverything.h" + +#include +#include +#include +#include +#include + +static void registerAllUpstreamDialects(MlirContext ctx) { + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + mlirRegisterAllDialects(registry); + mlirContextAppendDialectRegistry(ctx, registry); + mlirDialectRegistryDestroy(registry); +} + +void lowerModuleToLLVM(MlirContext ctx, MlirModule module) { + MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirOpPassManager opm = mlirPassManagerGetNestedUnder( + pm, mlirStringRefCreateFromCString("func.func")); + mlirPassManagerAddOwnedPass(pm, 
mlirCreateConversionConvertFuncToLLVMPass()); + mlirOpPassManagerAddOwnedPass( + opm, mlirCreateConversionArithToLLVMConversionPass()); + MlirLogicalResult status = + mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module)); + if (mlirLogicalResultIsFailure(status)) { + fprintf(stderr, "Unexpected failure running pass pipeline\n"); + exit(2); + } + mlirPassManagerDestroy(pm); +} + +// Helper variable to track callback invocations +static int initCnt = 0; + +// Callback function that will be called during JIT initialization +static void initCallback(void) { initCnt += 1; } + +// CHECK-LABEL: Running test 'testGlobalCtorJitCallback' +void testGlobalCtorJitCallback(void) { + MlirContext ctx = mlirContextCreate(); + registerAllUpstreamDialects(ctx); + + // Create module with global constructor that calls our callback + MlirModule module = mlirModuleCreateParse( + ctx, mlirStringRefCreateFromCString( + // clang-format off +"module { \n" +" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n" +" llvm.func @ctor() { \n" +" func.call @init_callback() : () -> () \n" +" llvm.return \n" +" } \n" +" func.func private @init_callback() attributes { llvm.emit_c_interface } \n" +"} \n" + // clang-format on + )); + + lowerModuleToLLVM(ctx, module); + mlirRegisterAllLLVMTranslations(ctx); + + // Create the execution engine; initialization is deferred until requested + MlirExecutionEngine jit = mlirExecutionEngineCreate( + module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, + /*enableObjectDump=*/false); + + if (mlirExecutionEngineIsNull(jit)) { + fprintf(stderr, "Execution engine creation failed"); + exit(2); + } + + // Register callback symbol before initialization + mlirExecutionEngineRegisterSymbol( + jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"), + (void *)(uintptr_t)initCallback); + + mlirExecutionEngineInitialize(jit); + + // CHECK: Init count: 1 + printf("Init count: %d\n", initCnt); + + mlirExecutionEngineDestroy(jit); + 
mlirModuleDestroy(module); + mlirContextDestroy(ctx); +} + +int main(void) { + +#define _STRINGIFY(x) #x +#define STRINGIFY(x) _STRINGIFY(x) +#define TEST(test) \ + printf("Running test '" STRINGIFY(test) "'\n"); \ + test(); + TEST(testGlobalCtorJitCallback); + return 0; +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index a4a942de3c9a7..0b98eaaf3391c 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -141,6 +141,7 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH}) llc mlir_async_runtime mlir-capi-execution-engine-test + mlir-capi-global-constructors-test mlir_c_runner_utils mlir_runner_utils mlir_float16_utils diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index ba7eeeed8ef3f..5e9347d784b38 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -190,6 +190,7 @@ def find_real_python_interpreter(): "mlir-translate", "mlir-lsp-server", "mlir-capi-execution-engine-test", + "mlir-capi-global-constructors-test", "mlir-capi-ir-test", "mlir-capi-irdl-test", "mlir-capi-llvm-test", diff --git a/mlir/test/python/global_constructors.py b/mlir/test/python/global_constructors.py new file mode 100644 index 0000000000000..5020c00344a33 --- /dev/null +++ b/mlir/test/python/global_constructors.py @@ -0,0 +1,72 @@ +# UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}} +# RUN: %PYTHON %s 2>&1 | FileCheck %s +# REQUIRES: host-supports-jit +import gc, sys, os, tempfile +from mlir.ir import * +from mlir.passmanager import * +from mlir.execution_engine import * +from mlir.runtime import * + + +# Log everything to stderr and flush so that we have a unified stream to match +# errors/info emitted by MLIR to stderr. 
+def log(*args): + print(*args, file=sys.stderr) + sys.stderr.flush() + + +def run(f): + log("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +def lowerToLLVM(module): + pm = PassManager.parse( + "builtin.module(convert-func-to-llvm,reconcile-unrealized-casts)" + ) + pm.run(module.operation) + return module + + +# Test JIT callback in global constructor +# CHECK-LABEL: TEST: testJITCallbackInGlobalCtor +def testJITCallbackInGlobalCtor(): + init_cnt = 0 + + @ctypes.CFUNCTYPE(None) + def initCallback(): + nonlocal init_cnt + init_cnt += 1 + + with Context(): + module = Module.parse( + r""" +llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] +llvm.func @ctor() { + func.call @init_callback() : () -> () + llvm.return +} +func.func private @init_callback() attributes { llvm.emit_c_interface } + """ + ) + + # Setup execution engine + execution_engine = ExecutionEngine(lowerToLLVM(module)) + + # Validate initialization hasn't run yet + assert init_cnt == 0 + + # Register the callback before initialization so the ctor can call it + execution_engine.register_runtime("init_callback", initCallback) + + # Initialize and verify the global constructor ran exactly once + execution_engine.initialize() + assert init_cnt == 1 + # A second initialization should be a no-op + execution_engine.initialize() + assert init_cnt == 1 + + +run(testJITCallbackInGlobalCtor) diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp index 312b10f28143f..b9a46c5ce942f 100644 --- a/mlir/unittests/ExecutionEngine/Invoke.cpp +++ b/mlir/unittests/ExecutionEngine/Invoke.cpp @@ -322,4 +322,55 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) { ASSERT_EQ(elt, coefficient * count++); } +static int initCnt = 0; +// A helper function that will be called during the JIT's initialization. 
+static void initCallback() { initCnt += 1; } + +TEST(MLIRExecutionEngine, MAYBE_JITCallbackInGlobalCtor) { + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + ASSERT_TRUE(!!tmBuilderOrError); + if (tmBuilderOrError->getTargetTriple().isAArch64()) { + GTEST_SKIP() << "Skipping global ctor initialization test on Aarch64 " + "because of bug #71963"; + return; + } + std::string moduleStr = R"mlir( + llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] + llvm.func @ctor() { + func.call @init_callback() : () -> () + llvm.return + } + func.func private @init_callback() attributes { llvm.emit_c_interface } + )mlir"; + + DialectRegistry registry; + registerAllDialects(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + MLIRContext context(registry); + auto module = parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); + ExecutionEngineOptions jitOptions; + auto jitOrError = ExecutionEngine::create(*module, jitOptions); + ASSERT_TRUE(!!jitOrError); + // Validate that construction alone does not run the global constructors. + ASSERT_EQ(initCnt, 0); + auto jit = std::move(jitOrError.get()); + // Define any extra symbols so they're available at initialization. + jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { + llvm::orc::SymbolMap symbolMap; + symbolMap[interner("_mlir_ciface_init_callback")] = { + llvm::orc::ExecutorAddr::fromPtr(initCallback), + llvm::JITSymbolFlags::Exported}; + return symbolMap; + }); + jit->initialize(); + // Validate the side effect of initialization: the ctor ran exactly once. + ASSERT_EQ(initCnt, 1); + // A second initialization must be a no-op. + jit->initialize(); + ASSERT_EQ(initCnt, 1); +} + #endif // _WIN32