Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mlir/docs/Dialects/GPU.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,25 @@ llvm.func @foo() {
// mlir-translate --mlir-to-llvmir:
@binary_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
@binary_func_kernel_name = private unnamed_addr constant [7 x i8] c"func\00", align 1
@binary_module = internal global ptr null
@llvm.global_ctors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_load, ptr null}]
@llvm.global_dtors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_unload, ptr null}]
define internal void @binary_load() section ".text.startup" {
entry:
%0 = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
store ptr %0, ptr @binary_module
...
}
define internal void @binary_unload() section ".text.startup" {
entry:
%0 = load ptr, ptr @binary_module, align 8
call void @mgpuModuleUnload(ptr %0)
...
}
...
define void @foo() {
...
%module = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
%module = load ptr, ptr @binary_module, align 8
%kernel = call ptr @mgpuModuleGetFunction(ptr %module, ptr @binary_func_kernel_name)
call void @mgpuLaunchKernel(ptr %kernel, ...) ; Launch the kernel
...
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir-c/ExecutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths, bool enableObjectDump);

/// Initialize the ExecutionEngine. Global constructors specified by
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
/// binary compiled from `gpu.module` gets loaded during initialization. Make
/// sure all symbols are resolvable before initialization by calling
/// `mlirExecutionEngineRegisterSymbol` or including shared libraries.
MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);

/// Destroy an ExecutionEngine instance.
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);

Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ class ExecutionEngine {
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
symbolMap);

/// Initialize the ExecutionEngine. Global constructors specified by
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
/// binary compiled from `gpu.module` gets loaded during initialization. Make
/// sure all symbols are resolvable before initialization by calling
/// `registerSymbols` or including shared libraries.
void initialize();

private:
/// Ordering of llvmContext and jit is important for destruction purposes: the
/// jit must be destroyed before the context.
Expand All @@ -250,6 +257,8 @@ class ExecutionEngine {
/// Destroy functions in the libraries loaded by the ExecutionEngine that are
/// called when this ExecutionEngine is destructed.
SmallVector<LibraryDestroyFn> destroyFns;

bool isInitialized = false;
};

} // namespace mlir
Expand Down
13 changes: 12 additions & 1 deletion mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir-c/ExecutionEngine.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
using namespace mlir;
Expand Down Expand Up @@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
},
nb::arg("name"), nb::arg("callback"),
"Register `callback` as the runtime symbol `name`.")
.def(
"initialize",
[](PyExecutionEngine &executionEngine) {
mlirExecutionEngineInitialize(executionEngine.get());
},
"Initialize the ExecutionEngine. Global constructors specified by "
"`llvm.mlir.global_ctors` will be run. One common scenario is that "
"kernel binary compiled from `gpu.module` gets loaded during "
"initialization. Make sure all symbols are resolvable before "
"initialization by calling `register_runtime` or including "
"shared libraries.")
.def(
"dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) {
Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
return wrap(jitOrError->release());
}

extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
unwrap(jit)->initialize();
}

extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
delete (unwrap(jit));
}
Expand Down Expand Up @@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
void *sym) {
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
symbolMap[interner(unwrap(name))] =
{ llvm::orc::ExecutorAddr::fromPtr(sym),
llvm::JITSymbolFlags::Exported };
symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
llvm::JITSymbolFlags::Exported};
return symbolMap;
});
}
Expand Down
20 changes: 12 additions & 8 deletions mlir/lib/ExecutionEngine/ExecutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
}
// Compilation is lazy and it doesn't populate object cache unless requested.
// In case object dump is requested before cache is populated, we need to
// force compilation manually.
// force compilation manually.
if (cache->isEmpty()) {
for (std::string &functionName : functionNames) {
auto result = lookupPacked(functionName);
Expand Down Expand Up @@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
return symbolMap;
};
engine->registerSymbols(runtimeSymbolMap);

// Execute the global constructors from the module being processed.
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
// crash for AArch64 see related issue #71963.
if (!engine->jit->getTargetTriple().isAArch64())
cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));

return std::move(engine);
}

Expand Down Expand Up @@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {

Error ExecutionEngine::invokePacked(StringRef name,
MutableArrayRef<void *> args) {
initialize();
auto expectedFPtr = lookupPacked(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
Expand All @@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,

return Error::success();
}

void ExecutionEngine::initialize() {
if (isInitialized)
return;
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
// crash for AArch64 see related issue #71963.
if (!jit->getTargetTriple().isAArch64())
cantFail(jit->initialize(jit->getMainJITDylib()));
isInitialized = true;
}
2 changes: 2 additions & 0 deletions mlir/lib/ExecutionEngine/JitRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,

auto engine = std::move(*expectedEngine);

engine->initialize();

auto expectedFPtr = engine->lookupPacked(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ class ExecutionEngine:
def dump_to_object_file(self, file_name: str) -> None: ...
def raw_lookup(self, func_name: str) -> int: ...
def raw_register_runtime(self, name: str, callback: object) -> None: ...
def init() -> None: ...
@property
def _CAPIPtr(self) -> object: ...
7 changes: 7 additions & 0 deletions mlir/test/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
MLIRCAPIConversion
MLIRCAPIExecutionEngine
MLIRCAPIRegisterEverything
)
_add_capi_test_executable(mlir-capi-global-constructors-test
global_constructors.c
LINK_LIBS PRIVATE
MLIRCAPIConversion
MLIRCAPIExecutionEngine
MLIRCAPIRegisterEverything
)
endif()

Expand Down
113 changes: 113 additions & 0 deletions mlir/test/CAPI/global_constructors.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//===- global_constructors.c - Test JIT with the global constructors ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}}
/* RUN: mlir-capi-global-constructors-test 2>&1 | FileCheck %s
*/
/* REQUIRES: host-supports-jit
*/

#include "mlir-c/Conversion.h"
#include "mlir-c/ExecutionEngine.h"
#include "mlir-c/IR.h"
#include "mlir-c/RegisterEverything.h"

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static void registerAllUpstreamDialects(MlirContext ctx) {
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterAllDialects(registry);
mlirContextAppendDialectRegistry(ctx, registry);
mlirDialectRegistryDestroy(registry);
}

void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirOpPassManager opm = mlirPassManagerGetNestedUnder(
pm, mlirStringRefCreateFromCString("func.func"));
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
mlirOpPassManagerAddOwnedPass(
opm, mlirCreateConversionArithToLLVMConversionPass());
MlirLogicalResult status =
mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure running pass pipeline\n");
exit(2);
}
mlirPassManagerDestroy(pm);
}

// Helper variable to track callback invocations
static int initCnt = 0;

// Callback function that will be called during JIT initialization
static void initCallback(void) { initCnt += 1; }

// CHECK-LABEL: Running test 'testGlobalCtorJitCallback'
void testGlobalCtorJitCallback(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);

// Create module with global constructor that calls our callback
MlirModule module = mlirModuleCreateParse(
ctx, mlirStringRefCreateFromCString(
// clang-format off
"module { \n"
" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n"
" llvm.func @ctor() { \n"
" func.call @init_callback() : () -> () \n"
" llvm.return \n"
" } \n"
" func.func private @init_callback() attributes { llvm.emit_c_interface } \n"
"} \n"
// clang-format on
));

lowerModuleToLLVM(ctx, module);
mlirRegisterAllLLVMTranslations(ctx);

// Create execution engine with initialization disabled
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
/*enableObjectDump=*/false);

if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Execution engine creation failed");
exit(2);
}

// Register callback symbol before initialization
mlirExecutionEngineRegisterSymbol(
jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"),
(void *)(uintptr_t)initCallback);

mlirExecutionEngineInitialize(jit);

// CHECK: Init count: 1
printf("Init count: %d\n", initCnt);

mlirExecutionEngineDestroy(jit);
mlirModuleDestroy(module);
mlirContextDestroy(ctx);
}

int main(void) {

#define _STRINGIFY(x) #x
#define STRINGIFY(x) _STRINGIFY(x)
#define TEST(test) \
printf("Running test '" STRINGIFY(test) "'\n"); \
test();
TEST(testGlobalCtorJitCallback);
return 0;
}
1 change: 1 addition & 0 deletions mlir/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ if(LLVM_ENABLE_PIC AND TARGET ${LLVM_NATIVE_ARCH})
llc
mlir_async_runtime
mlir-capi-execution-engine-test
mlir-capi-global-constructors-test
mlir_c_runner_utils
mlir_runner_utils
mlir_float16_utils
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def find_real_python_interpreter():
"mlir-translate",
"mlir-lsp-server",
"mlir-capi-execution-engine-test",
"mlir-capi-global-constructors-test",
"mlir-capi-ir-test",
"mlir-capi-irdl-test",
"mlir-capi-llvm-test",
Expand Down
72 changes: 72 additions & 0 deletions mlir/test/python/global_constructors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# UNSUPPORTED: target=aarch64{{.*}}, target=arm64{{.*}}
# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
import gc, sys, os, tempfile
from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
print(*args, file=sys.stderr)
sys.stderr.flush()


def run(f):
log("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0


def lowerToLLVM(module):
pm = PassManager.parse(
"builtin.module(convert-func-to-llvm,reconcile-unrealized-casts)"
)
pm.run(module.operation)
return module


# Test JIT callback in global constructor
# CHECK-LABEL: TEST: testJITCallbackInGlobalCtor
def testJITCallbackInGlobalCtor():
init_cnt = 0

@ctypes.CFUNCTYPE(None)
def initCallback():
nonlocal init_cnt
init_cnt += 1

with Context():
module = Module.parse(
r"""
llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
llvm.func @ctor() {
func.call @init_callback() : () -> ()
llvm.return
}
func.func private @init_callback() attributes { llvm.emit_c_interface }
"""
)

# Setup execution engine
execution_engine = ExecutionEngine(lowerToLLVM(module))

# Validate initialization hasn't run yet
assert init_cnt == 0

# # Register callback
execution_engine.register_runtime("init_callback", initCallback)

# # Initialize and verify
execution_engine.initialize()
assert init_cnt == 1
# # Second initialization should be no-op
execution_engine.initialize()
assert init_cnt == 1


run(testJITCallbackInGlobalCtor)
Loading