Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,13 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this opens up the possibility for the callbacks to modify the MLIR IR, which is a pretty wide opening of the API, I don't think this is what this API is intended to support.

Copy link
Contributor Author

@Hardcode84 Hardcode84 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I can do it though the error callbacks like I did before https://github.com/llvm/llvm-project/blob/main/mlir/lib/Transforms/CompositePass.cpp#L45

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature is strange to me for a errorHandler. Look at the usual pattern in MLIR: https://github.com/search?q=repo%3Allvm%2Fllvm-project%20%22%3E%20emitError%22&type=code

 function_ref<InFlightDiagnostic()> emitError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to InFlightDiagnostic


/// Returns the typeID.
TypeID getTypeID() const;
Expand Down Expand Up @@ -97,19 +100,22 @@ class TargetOptions {

/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
getInitialLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
getLinkedLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
getOptimizedLlvmIRCallback() const;

/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> getISACallback() const;
function_ref<LogicalResult(Operation *op, StringRef)> getISACallback() const;

/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
Expand All @@ -127,10 +133,13 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});

/// Path to the target toolkit.
std::string toolkitPath;
Expand All @@ -153,19 +162,22 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;

private:
TypeID typeID;
Expand Down
22 changes: 14 additions & 8 deletions mlir/include/mlir/Target/LLVM/ModuleToObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class ModuleToObject {
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
virtual ~ModuleToObject();

/// Returns the operation being serialized.
Expand Down Expand Up @@ -120,19 +123,22 @@ class ModuleToObject {
int optLevel;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback;

private:
/// The TargetMachine created for the given Triple, if available.
Expand Down
31 changes: 19 additions & 12 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2652,10 +2652,13 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback,
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
Expand All @@ -2667,10 +2670,13 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback,
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
Expand All @@ -2696,22 +2702,23 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
function_ref<LogicalResult(Operation *op, llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
function_ref<LogicalResult(Operation *op, llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
function_ref<LogicalResult(Operation *op, llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
function_ref<LogicalResult(Operation *op, StringRef)>
TargetOptions::getISACallback() const {
return isaCallback;
}

Expand Down
21 changes: 14 additions & 7 deletions mlir/lib/Target/LLVM/ModuleToObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ using namespace mlir::LLVM;

ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
int optLevel,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
initialLlvmIRCallback,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
linkedLlvmIRCallback,
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback,
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
Expand Down Expand Up @@ -255,7 +259,8 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
setDataLayoutAndTriple(*llvmModule);

if (initialLlvmIRCallback)
initialLlvmIRCallback(*llvmModule);
if (failed(initialLlvmIRCallback(&getOperation(), *llvmModule)))
return std::nullopt;

// Link bitcode files.
handleModulePreLink(*llvmModule);
Expand All @@ -270,14 +275,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}

if (linkedLlvmIRCallback)
linkedLlvmIRCallback(*llvmModule);
if (failed(linkedLlvmIRCallback(&getOperation(), *llvmModule)))
return std::nullopt;

// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;

if (optimizedLlvmIRCallback)
optimizedLlvmIRCallback(*llvmModule);
if (failed(optimizedLlvmIRCallback(&getOperation(), *llvmModule)))
return std::nullopt;

// Return the serialized object.
return moduleToObject(*llvmModule);
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
}

if (isaCallback)
isaCallback(serializedISA.value());
if (failed(isaCallback(getOperation(), serializedISA.value())))
return std::nullopt;

#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
Expand Down
51 changes: 46 additions & 5 deletions mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,36 @@ TEST_F(MLIRTargetLLVMNVVM,
ASSERT_TRUE(!!serializer);

std::string initialLLVMIR;
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
auto initialCallback =
[&initialLLVMIR](Operation * /*op*/,
llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string optimizedLLVMIR;
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
auto optimizedCallback =
[&optimizedLLVMIR](Operation * /*op*/,
llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
return success();
};

std::string isaResult;
auto isaCallback = [&isaResult](llvm::StringRef isa) {
auto isaCallback = [&isaResult](Operation * /*op*/,
llvm::StringRef isa) -> LogicalResult {
isaResult = isa.str();
return success();
};

gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
Expand All @@ -220,6 +230,35 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}

// Test callback functions failure with ISA.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
MLIRContext context(registry);

OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);

NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);

auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
ASSERT_TRUE(!!serializer);

auto isaCallback = [](Operation * /*op*/,
llvm::StringRef /*isa*/) -> LogicalResult {
return failure();
};

gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
{}, {}, {}, {}, isaCallback);

for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
std::optional<SmallVector<char, 0>> object =
serializer.serializeToObject(gpuModule, options);

ASSERT_TRUE(object == std::nullopt);
}
}

// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
Expand Down Expand Up @@ -261,9 +300,11 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {

// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
auto linkedCallback = [&linkedLLVMIR](Operation * /*op*/,
llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
return success();
};

// Store the bitcode as a DenseI8ArrayAttr.
Expand Down
Loading