Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/include/mlir-c/ExecutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void);
/// that will be loaded are specified via `numPaths` and `sharedLibPaths`
/// respectively.
/// TODO: figure out other options.
MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths, bool enableObjectDump);
MLIR_CAPI_EXPORTED MlirExecutionEngine
mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths,
bool enableObjectDump, bool shouldInitialize);

/// Initialize the ExecutionEngine by executing the global constructors from the
/// module (those specified by `llvm.mlir.global_ctors`). Calling this on an
/// engine that has already been initialized is a no-op. Make sure every symbol
/// the constructors reference is resolvable before calling it, e.g. by
/// registering it with `mlirExecutionEngineRegisterSymbol` or by passing the
/// shared libraries that define it when the engine is created.
MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);

/// Destroy an ExecutionEngine instance.
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ struct ExecutionEngineOptions {
/// If `enablePerfNotificationListener` is set, the JIT compiler will notify
/// the llvm's global Perf notification listener.
bool enablePerfNotificationListener = true;

/// If `shouldInitialize` is true (the default), `ExecutionEngine::create`
/// calls `initialize()` before returning the engine. Set it to false to defer
/// initialization, e.g. to register symbols used by global constructors
/// first, and then call `initialize()` explicitly.
bool shouldInitialize = true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we need a flag for this, seems to me that since we introduce a separate initialize() method we don't need an option and users can just call initialize() separately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we need a flag for this, seems to me that since we introduce a separate initialize() method we don't need an option and users can just call initialize() separately.

It is introduced to reduce the changes downstream has to make.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not the usual criterion for development in LLVM, though: we aim to keep the code here maintainable and preserve our ability to innovate. Downstream users know they need to adapt when we find and fix design issues upstream.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we lazily initialize() when needed if the user hasn't called initialize()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we lazily initialize() when needed if the user hasn't called initialize()?

Good idea. The shouldInitialize flag is removed. Inside invokePacked we initialize() (noop if already initialized).

};

/// JIT-backed execution engine for MLIR. Assumes the IR can be converted to
Expand Down Expand Up @@ -227,6 +230,13 @@ class ExecutionEngine {
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
symbolMap);

/// Initialize the ExecutionEngine. Global constructors specified by
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
/// binary compiled from `gpu.module` gets loaded during initialization. Make
/// sure all symbols are resolvable before initialization by calling
/// `registerSymbols` or including shared libraries.
void initialize();

private:
/// Ordering of llvmContext and jit is important for destruction purposes: the
/// jit must be destroyed before the context.
Expand All @@ -250,6 +260,8 @@ class ExecutionEngine {
/// Destroy functions in the libraries loaded by the ExecutionEngine that are
/// called when this ExecutionEngine is destructed.
SmallVector<LibraryDestroyFn> destroyFns;

/// Tracks whether `initialize()` has already run so that repeated calls
/// return without initializing again.
bool isInitialized = false;
};

} // namespace mlir
Expand Down
17 changes: 12 additions & 5 deletions mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir-c/ExecutionEngine.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
using namespace mlir;
Expand Down Expand Up @@ -75,13 +75,13 @@ NB_MODULE(_mlirExecutionEngine, m) {
"__init__",
[](PyExecutionEngine &self, MlirModule module, int optLevel,
const std::vector<std::string> &sharedLibPaths,
bool enableObjectDump) {
bool enableObjectDump, bool shouldInitialize) {
llvm::SmallVector<MlirStringRef, 4> libPaths;
for (const std::string &path : sharedLibPaths)
libPaths.push_back({path.c_str(), path.length()});
MlirExecutionEngine executionEngine =
mlirExecutionEngineCreate(module, optLevel, libPaths.size(),
libPaths.data(), enableObjectDump);
MlirExecutionEngine executionEngine = mlirExecutionEngineCreate(
module, optLevel, libPaths.size(), libPaths.data(),
enableObjectDump, shouldInitialize);
if (mlirExecutionEngineIsNull(executionEngine))
throw std::runtime_error(
"Failure while creating the ExecutionEngine.");
Expand All @@ -90,6 +90,7 @@ NB_MODULE(_mlirExecutionEngine, m) {
nb::arg("module"), nb::arg("opt_level") = 2,
nb::arg("shared_libs") = nb::list(),
nb::arg("enable_object_dump") = true,
nb::arg("should_initialize") = true,
"Create a new ExecutionEngine instance for the given Module. The "
"module must contain only dialects that can be translated to LLVM. "
"Perform transformations and code generation at the optimization "
Expand Down Expand Up @@ -124,6 +125,12 @@ NB_MODULE(_mlirExecutionEngine, m) {
},
nb::arg("name"), nb::arg("callback"),
"Register `callback` as the runtime symbol `name`.")
.def(
"initialize",
[](PyExecutionEngine &executionEngine) {
mlirExecutionEngineInitialize(executionEngine.get());
},
"Initialize the ExecutionEngine.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand the doc? This doc isn't very useful as it just repeats the method name. We need to describe the expectations from the user point of view.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand the doc? This doc isn't very useful as it just repeats the method name. We need to describe the expectations from the user point of view.

Docs for C++, C, and Python have been brought in sync, with added guidance on calling the different 'register symbols' variants.

.def(
"dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) {
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using namespace mlir;
extern "C" MlirExecutionEngine
mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths,
bool enableObjectDump) {
bool enableObjectDump, bool shouldInitialize) {
static bool initOnce = [] {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm
Expand Down Expand Up @@ -60,6 +60,7 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
jitOptions.jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(optLevel);
jitOptions.sharedLibPaths = libPaths;
jitOptions.enableObjectDump = enableObjectDump;
jitOptions.shouldInitialize = shouldInitialize;
auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions);
if (!jitOrError) {
consumeError(jitOrError.takeError());
Expand All @@ -68,6 +69,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
return wrap(jitOrError->release());
}

/// C API entry point: forwards to `initialize()` on the engine wrapped by
/// `jit`.
extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
  auto *engine = unwrap(jit);
  engine->initialize();
}

/// C API entry point: deletes the engine wrapped by `jit`.
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
  auto *engine = unwrap(jit);
  delete engine;
}
Expand Down
21 changes: 13 additions & 8 deletions mlir/lib/ExecutionEngine/ExecutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
}
// Compilation is lazy and it doesn't populate object cache unless requested.
// In case object dump is requested before cache is populated, we need to
// force compilation manually.
// force compilation manually.
if (cache->isEmpty()) {
for (std::string &functionName : functionNames) {
auto result = lookupPacked(functionName);
Expand Down Expand Up @@ -400,13 +400,8 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
return symbolMap;
};
engine->registerSymbols(runtimeSymbolMap);

// Execute the global constructors from the module being processed.
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
// crash for AArch64 see related issue #71963.
if (!engine->jit->getTargetTriple().isAArch64())
cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));

if (options.shouldInitialize)
engine->initialize();
return std::move(engine);
}

Expand Down Expand Up @@ -451,3 +446,13 @@ Error ExecutionEngine::invokePacked(StringRef name,

return Error::success();
}

/// Runs the JIT initialization of the main JITDylib the first time it is
/// called; every later call returns without doing anything.
void ExecutionEngine::initialize() {
  if (!isInitialized) {
    // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
    // crash for AArch64 see related issue #71963.
    const bool targetsAArch64 = jit->getTargetTriple().isAArch64();
    if (!targetsAArch64)
      cantFail(jit->initialize(jit->getMainJITDylib()));
    // The flag is set on AArch64 as well, even though the JIT step is skipped.
    isInitialized = true;
  }
}
59 changes: 57 additions & 2 deletions mlir/test/CAPI/execution_engine.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void testSimpleExecution(void) {
mlirRegisterAllLLVMTranslations(ctx);
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
/*enableObjectDump=*/false);
/*enableObjectDump=*/false, /*shouldInitialize=*/true);
if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Execution engine creation failed");
exit(2);
Expand Down Expand Up @@ -125,7 +125,7 @@ void testOmpCreation(void) {
// against the OpenMP library.
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
/*enableObjectDump=*/false);
/*enableObjectDump=*/false, /*shouldInitialize=*/true);
if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Engine creation failed with OpenMP");
exit(2);
Expand All @@ -137,6 +137,60 @@ void testOmpCreation(void) {
mlirContextDestroy(ctx);
}

// Number of times `initCallback` has been invoked so far.
static int initCnt = 0;

// Registered with the JIT as the target of the module's `@init_callback`;
// bumps `initCnt` by one on every call.
static void initCallback(void) { ++initCnt; }

// CHECK-LABEL: Running test 'testGlobalCtorJitCallback'
// Exercises deferred initialization through the C API: the engine is created
// with `shouldInitialize=false`, the callback symbol called by the module's
// global constructor is registered, and only then is the engine initialized.
// The resulting callback count is printed for the FileCheck directive below.
void testGlobalCtorJitCallback(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);

// Create module with global constructor that calls our callback
MlirModule module = mlirModuleCreateParse(
ctx, mlirStringRefCreateFromCString(
// clang-format off
"module { \n"
" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n"
" llvm.func @ctor() { \n"
" func.call @init_callback() : () -> () \n"
" llvm.return \n"
" } \n"
" func.func private @init_callback() attributes { llvm.emit_c_interface } \n"
"} \n"
// clang-format on
));

lowerModuleToLLVM(ctx, module);
mlirRegisterAllLLVMTranslations(ctx);

// Create execution engine with initialization disabled
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
/*enableObjectDump=*/false, /*shouldInitialize=*/false);

if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Execution engine creation failed");
exit(2);
}

// Register callback symbol before initialization
// NOTE(review): the `_mlir_ciface_` prefix presumably comes from the
// `llvm.emit_c_interface` attribute on `@init_callback` - confirm.
mlirExecutionEngineRegisterSymbol(
jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"),
(void *)(uintptr_t)initCallback);

// Runs the global constructor, which is expected to call `initCallback`.
mlirExecutionEngineInitialize(jit);

// CHECK: Init count: 1
printf("Init count: %d\n", initCnt);

// Tear down in reverse order of creation.
mlirExecutionEngineDestroy(jit);
mlirModuleDestroy(module);
mlirContextDestroy(ctx);
}

int main(void) {

#define _STRINGIFY(x) #x
Expand All @@ -147,5 +201,6 @@ int main(void) {

TEST(testSimpleExecution);
TEST(testOmpCreation);
TEST(testGlobalCtorJitCallback);
return 0;
}
40 changes: 40 additions & 0 deletions mlir/unittests/ExecutionEngine/Invoke.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,44 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
ASSERT_EQ(elt, coefficient * count++);
}

// Number of times `initCallback` has been invoked so far.
static int initCnt = 0;
// Host-side hook registered with the JIT; the module's global constructor
// calls it while the engine is being initialized.
static void initCallback() { ++initCnt; }

TEST(GlobalCtorJit, MAYBE_JITCallback) {
  // Module whose global constructor calls back into the host through
  // `@init_callback`.
  std::string ctorModuleSource = R"mlir(
llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
llvm.func @ctor() {
func.call @init_callback() : () -> ()
llvm.return
}
func.func private @init_callback() attributes { llvm.emit_c_interface }
)mlir";

  // Set up a context able to parse the module and translate it to LLVM IR.
  DialectRegistry dialects;
  registerAllDialects(dialects);
  registerBuiltinDialectTranslation(dialects);
  registerLLVMDialectTranslation(dialects);
  MLIRContext context(dialects);

  auto parsedModule = parseSourceString<ModuleOp>(ctorModuleSource, &context);
  ASSERT_TRUE(static_cast<bool>(parsedModule));
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*parsedModule)));

  // Create the engine without initializing it, so the symbol used by the
  // constructor can be registered first.
  ExecutionEngineOptions engineOptions;
  engineOptions.shouldInitialize = false;
  auto engineOrError = ExecutionEngine::create(*parsedModule, engineOptions);
  ASSERT_TRUE(static_cast<bool>(engineOrError));
  auto engine = std::move(*engineOrError);

  // Make the host callback resolvable before the constructors run.
  engine->registerSymbols([](llvm::orc::MangleAndInterner mangle) {
    llvm::orc::SymbolMap hostSymbols;
    hostSymbols[mangle("_mlir_ciface_init_callback")] = {
        llvm::orc::ExecutorAddr::fromPtr(initCallback),
        llvm::JITSymbolFlags::Exported};
    return hostSymbols;
  });

  // Initialization runs the constructor, which must have called the callback
  // exactly once.
  engine->initialize();
  ASSERT_EQ(initCnt, 1);
}

#endif // _WIN32