-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][gpu] Propagate errors from ModuleToObject callbacks
#170134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-llvm Author: Ivan Butygin (Hardcode84) ChangesInitial discussion #170016 (comment) While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail. Full diff: https://github.com/llvm/llvm-project/pull/170134.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..00f885898ffa1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -55,10 +55,10 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -97,19 +97,20 @@ class TargetOptions {
/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
- function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
- function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)>
+ getOptimizedLlvmIRCallback() const;
/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> getISACallback() const;
+ function_ref<LogicalResult(StringRef)> getISACallback() const;
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -127,10 +128,10 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -153,19 +154,19 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> isaCallback;
+ function_ref<LogicalResult(StringRef)> isaCallback;
private:
TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..0edc20cd32620 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -32,10 +32,10 @@ class ModuleToObject {
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -120,19 +120,19 @@ class ModuleToObject {
int optLevel;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> isaCallback;
+ function_ref<LogicalResult(StringRef)> isaCallback;
private:
/// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..a813608fdf209 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}
-function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..d881dda69453b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -36,10 +36,11 @@ using namespace mlir::LLVM;
ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
- int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ int optLevel,
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
- if (initialLlvmIRCallback)
- initialLlvmIRCallback(*llvmModule);
+ if (initialLlvmIRCallback) {
+ if (failed(initialLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "InitialLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Link bitcode files.
handleModulePreLink(*llvmModule);
@@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
- if (linkedLlvmIRCallback)
- linkedLlvmIRCallback(*llvmModule);
+ if (linkedLlvmIRCallback) {
+ if (failed(linkedLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "LinkedLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
- if (optimizedLlvmIRCallback)
- optimizedLlvmIRCallback(*llvmModule);
+ if (optimizedLlvmIRCallback) {
+ if (failed(optimizedLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Return the serialized object.
return moduleToObject(*llvmModule);
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..cbd6a6d878813 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}
- if (isaCallback)
- isaCallback(serializedISA.value());
+ if (isaCallback) {
+ if (failed(isaCallback(serializedISA.value()))) {
+ getOperation().emitError() << "ISACallback failed.";
+ return std::nullopt;
+ }
+ }
#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..1692c4490e4d1 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
ASSERT_TRUE(!!serializer);
std::string initialLLVMIR;
- auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ auto initialCallback =
+ [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string optimizedLLVMIR;
- auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ auto optimizedCallback =
+ [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string isaResult;
- auto isaCallback = [&isaResult](llvm::StringRef isa) {
+ auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
isaResult = isa.str();
+ return success();
};
gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}
+// Test callback functions failure with ISA.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+
+ NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+ auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+ ASSERT_TRUE(!!serializer);
+
+ auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
+ {}, {}, {}, {}, isaCallback);
+
+ for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+ std::optional<SmallVector<char, 0>> object =
+ serializer.serializeToObject(gpuModule, options);
+
+ ASSERT_TRUE(object == std::nullopt);
+ }
+}
+
// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
@@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
// Store the bitcode as a DenseI8ArrayAttr.
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 3c880edee4ffc..b392065132787 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string initialLLVMIR;
- auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ auto initialCallback =
+ [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM,
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string optimizedLLVMIR;
- auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ auto optimizedCallback =
+ [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM,
ASSERT_TRUE(!serializedBinary->empty());
ASSERT_TRUE(!optimizedLLVMIR.empty());
}
+
+// Test callback function failure with initial LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, initialCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with linked LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, {}, linkedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with optimized LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, {}, {}, optimizedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
|
|
@llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesInitial discussion #170016 (comment) While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail. Full diff: https://github.com/llvm/llvm-project/pull/170134.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..00f885898ffa1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -55,10 +55,10 @@ class TargetOptions {
StringRef cmdOptions = {}, StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -97,19 +97,20 @@ class TargetOptions {
/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
- function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
- function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+ function_ref<LogicalResult(llvm::Module &)>
+ getOptimizedLlvmIRCallback() const;
/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> getISACallback() const;
+ function_ref<LogicalResult(StringRef)> getISACallback() const;
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -127,10 +128,10 @@ class TargetOptions {
StringRef elfSection = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {},
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -153,19 +154,19 @@ class TargetOptions {
function_ref<SymbolTable *()> getSymbolTableCallback;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> isaCallback;
+ function_ref<LogicalResult(StringRef)> isaCallback;
private:
TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..0edc20cd32620 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -32,10 +32,10 @@ class ModuleToObject {
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
- function_ref<void(StringRef)> isaCallback = {});
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<LogicalResult(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -120,19 +120,19 @@ class ModuleToObject {
int optLevel;
/// Callback invoked with the initial LLVM IR for the device module.
- function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
- function_ref<void(StringRef)> isaCallback;
+ function_ref<LogicalResult(StringRef)> isaCallback;
private:
/// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..a813608fdf209 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
cmdOptions, elfSection, compilationTarget,
getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
StringRef cmdOptions, StringRef elfSection,
CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback,
- function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
compilationTarget(compilationTarget),
@@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}
-function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..d881dda69453b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -36,10 +36,11 @@ using namespace mlir::LLVM;
ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
- int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
- function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
- function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
- function_ref<void(StringRef)> isaCallback)
+ int optLevel,
+ function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<LogicalResult(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
- if (initialLlvmIRCallback)
- initialLlvmIRCallback(*llvmModule);
+ if (initialLlvmIRCallback) {
+ if (failed(initialLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "InitialLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Link bitcode files.
handleModulePreLink(*llvmModule);
@@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
- if (linkedLlvmIRCallback)
- linkedLlvmIRCallback(*llvmModule);
+ if (linkedLlvmIRCallback) {
+ if (failed(linkedLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "LinkedLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
- if (optimizedLlvmIRCallback)
- optimizedLlvmIRCallback(*llvmModule);
+ if (optimizedLlvmIRCallback) {
+ if (failed(optimizedLlvmIRCallback(*llvmModule))) {
+ getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
+ return std::nullopt;
+ }
+ }
// Return the serialized object.
return moduleToObject(*llvmModule);
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..cbd6a6d878813 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
return std::nullopt;
}
- if (isaCallback)
- isaCallback(serializedISA.value());
+ if (isaCallback) {
+ if (failed(isaCallback(serializedISA.value()))) {
+ getOperation().emitError() << "ISACallback failed.";
+ return std::nullopt;
+ }
+ }
#define DEBUG_TYPE "serialize-to-isa"
LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..1692c4490e4d1 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
ASSERT_TRUE(!!serializer);
std::string initialLLVMIR;
- auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ auto initialCallback =
+ [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string optimizedLLVMIR;
- auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ auto optimizedCallback =
+ [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
std::string isaResult;
- auto isaCallback = [&isaResult](llvm::StringRef isa) {
+ auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
isaResult = isa.str();
+ return success();
};
gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
}
}
+// Test callback functions failure with ISA.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+
+ NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+ auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+ ASSERT_TRUE(!!serializer);
+
+ auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
+ {}, {}, {}, {}, isaCallback);
+
+ for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+ std::optional<SmallVector<char, 0>> object =
+ serializer.serializeToObject(gpuModule, options);
+
+ ASSERT_TRUE(object == std::nullopt);
+ }
+}
+
// Test linking LLVM IR from a resource attribute.
TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
MLIRContext context(registry);
@@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
// Hook to intercept the LLVM IR after linking external libs.
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
// Store the bitcode as a DenseI8ArrayAttr.
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 3c880edee4ffc..b392065132787 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string initialLLVMIR;
- auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ auto initialCallback =
+ [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string linkedLLVMIR;
- auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM,
auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
std::string optimizedLLVMIR;
- auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ auto optimizedCallback =
+ [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
+ return success();
};
gpu::TargetOptions opts(
@@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM,
ASSERT_TRUE(!serializedBinary->empty());
ASSERT_TRUE(!optimizedLLVMIR.empty());
}
+
+// Test callback function failure with initial LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, initialCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with linked LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, {}, linkedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with optimized LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+ return failure();
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+ {}, {}, {}, optimizedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary == std::nullopt);
+}
|
fabianmcg
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general LGTM, the only change I would make is:
- Passing a Location as well to the callbacks
- Let the callback handle the error message, so no
getOperation().emitError() << "ISACallback failed.";
I'd argue that this is already a wanted change, as I could have a custom out of tree LLVM pass running at any of the LLVM callback levels that could fail.
joker-eph
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a diagnostic issue here: the callbacks don't have any information needed to diagnose (no MLIRContext, no operation to attach the diagnostic to...).
The use-case for this looks quite hand-wavy to me to justify adding more complexity to correctly handle all this though.
I can add source op arg to all the callbacks.
The one potential use case I considered is something similar to Triton |
| linkedLlvmIRCallback = {}, | ||
| function_ref<LogicalResult(Operation *op, llvm::Module &)> | ||
| optimizedLlvmIRCallback = {}, | ||
| function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now this opens up the possibility for the callbacks to modify the MLIR IR, which is a pretty wide opening of the API, I don't think this is what this API is intended to support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I can do it though the error callbacks like I did before https://github.com/llvm/llvm-project/blob/main/mlir/lib/Transforms/CompositePass.cpp#L45
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signature is strange to me for a errorHandler. Look at the usual pattern in MLIR: https://github.com/search?q=repo%3Allvm%2Fllvm-project%20%22%3E%20emitError%22&type=code
function_ref<InFlightDiagnostic()> emitError
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switched to InFlightDiagnostic
This can already be done outside of ModuleToObject the way you've done it in the pass to signalPassFailure, no need to add the complexity here for this. |
|
In general if there is a need to fine grain control the logic across the steps that are encapsulated, then I rather avoid the complexity of the many handshakes with callbacks, and expose the individual steps to allow the users to stitch them however they want. This is a fairly usual approach to modularity in MLIR: a nested grouping of components where you can skip every high level of API to reach to a lower-level API to customize the assembly. The higher-level APIs are conveniently exposed but not the only entry points. |
|
If we want to override llvm IR during linking/optimizations phases and/or override device assembly, we will have to do it in those callbacks and we will have callbacks to return errors. |
I don't believe this is needed: proof of existence if what you're doing in the other PR to catch the error. I'll reiterate that the callback design is creating a monolithic component, which is not modular. We went with this only because the callbacks were intended as "debugging" things that don't affect the behavior of ModuleToObject significantly. If we need to provide more flexibility to users, the more modular way to achieve it is by disaggregating the component into pieces that the user can assemble, instead of making the "monolith" more complex with complex config/callbacks mechanism. Hence my objection to the current direction here. |
Initial discussion #170016 (comment)
While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.