Skip to content

Conversation

@Hardcode84
Copy link
Contributor

Initial discussion #170016 (comment)

While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.

@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-llvm

Author: Ivan Butygin (Hardcode84)

Changes

Initial discussion #170016 (comment)

While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.


Full diff: https://github.com/llvm/llvm-project/pull/170134.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h (+17-16)
  • (modified) mlir/include/mlir/Target/LLVM/ModuleToObject.h (+8-8)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+12-12)
  • (modified) mlir/lib/Target/LLVM/ModuleToObject.cpp (+23-10)
  • (modified) mlir/lib/Target/LLVM/NVVM/Target.cpp (+6-2)
  • (modified) mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp (+40-5)
  • (modified) mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp (+80-3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..00f885898ffa1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -55,10 +55,10 @@ class TargetOptions {
       StringRef cmdOptions = {}, StringRef elfSection = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
       function_ref<SymbolTable *()> getSymbolTableCallback = {},
-      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<LogicalResult(StringRef)> isaCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -97,19 +97,20 @@ class TargetOptions {
 
   /// Returns the callback invoked with the initial LLVM IR for the device
   /// module.
-  function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+  function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
 
   /// Returns the callback invoked with LLVM IR for the device module
   /// after linking the device libraries.
-  function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+  function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
 
   /// Returns the callback invoked with LLVM IR for the device module after
   /// LLVM optimizations but before codegen.
-  function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+  function_ref<LogicalResult(llvm::Module &)>
+  getOptimizedLlvmIRCallback() const;
 
   /// Returns the callback invoked with the target ISA for the device,
   /// for example PTX assembly.
-  function_ref<void(StringRef)> getISACallback() const;
+  function_ref<LogicalResult(StringRef)> getISACallback() const;
 
   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
   static CompilationTarget getDefaultCompilationTarget();
@@ -127,10 +128,10 @@ class TargetOptions {
       StringRef elfSection = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
       function_ref<SymbolTable *()> getSymbolTableCallback = {},
-      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<LogicalResult(StringRef)> isaCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -153,19 +154,19 @@ class TargetOptions {
   function_ref<SymbolTable *()> getSymbolTableCallback;
 
   /// Callback invoked with the initial LLVM IR for the device module.
-  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// linking the device libraries.
-  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// LLVM optimizations but before codegen.
-  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
 
   /// Callback invoked with the target ISA for the device,
   /// for example PTX assembly.
-  function_ref<void(StringRef)> isaCallback;
+  function_ref<LogicalResult(StringRef)> isaCallback;
 
 private:
   TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..0edc20cd32620 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -32,10 +32,10 @@ class ModuleToObject {
   ModuleToObject(
       Operation &module, StringRef triple, StringRef chip,
       StringRef features = {}, int optLevel = 3,
-      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<LogicalResult(StringRef)> isaCallback = {});
   virtual ~ModuleToObject();
 
   /// Returns the operation being serialized.
@@ -120,19 +120,19 @@ class ModuleToObject {
   int optLevel;
 
   /// Callback invoked with the initial LLVM IR for the device module.
-  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// linking the device libraries.
-  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// LLVM optimizations but before codegen.
-  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
 
   /// Callback invoked with the target ISA for the device,
   /// for example PTX assembly.
-  function_ref<void(StringRef)> isaCallback;
+  function_ref<LogicalResult(StringRef)> isaCallback;
 
 private:
   /// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..a813608fdf209 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
     StringRef cmdOptions, StringRef elfSection,
     CompilationTarget compilationTarget,
     function_ref<SymbolTable *()> getSymbolTableCallback,
-    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
-    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
-    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<LogicalResult(StringRef)> isaCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                     cmdOptions, elfSection, compilationTarget,
                     getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
     StringRef cmdOptions, StringRef elfSection,
     CompilationTarget compilationTarget,
     function_ref<SymbolTable *()> getSymbolTableCallback,
-    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
-    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
-    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<LogicalResult(StringRef)> isaCallback)
     : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
       cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
       compilationTarget(compilationTarget),
@@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
   return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
 }
 
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
 TargetOptions::getInitialLlvmIRCallback() const {
   return initialLlvmIRCallback;
 }
 
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
 TargetOptions::getLinkedLlvmIRCallback() const {
   return linkedLlvmIRCallback;
 }
 
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
 TargetOptions::getOptimizedLlvmIRCallback() const {
   return optimizedLlvmIRCallback;
 }
 
-function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
   return isaCallback;
 }
 
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..d881dda69453b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -36,10 +36,11 @@ using namespace mlir::LLVM;
 
 ModuleToObject::ModuleToObject(
     Operation &module, StringRef triple, StringRef chip, StringRef features,
-    int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
-    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
-    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    int optLevel,
+    function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<LogicalResult(StringRef)> isaCallback)
     : module(module), triple(triple), chip(chip), features(features),
       optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
       linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
   }
   setDataLayoutAndTriple(*llvmModule);
 
-  if (initialLlvmIRCallback)
-    initialLlvmIRCallback(*llvmModule);
+  if (initialLlvmIRCallback) {
+    if (failed(initialLlvmIRCallback(*llvmModule))) {
+      getOperation().emitError() << "InitialLLVMIRCallback failed.";
+      return std::nullopt;
+    }
+  }
 
   // Link bitcode files.
   handleModulePreLink(*llvmModule);
@@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
     handleModulePostLink(*llvmModule);
   }
 
-  if (linkedLlvmIRCallback)
-    linkedLlvmIRCallback(*llvmModule);
+  if (linkedLlvmIRCallback) {
+    if (failed(linkedLlvmIRCallback(*llvmModule))) {
+      getOperation().emitError() << "LinkedLLVMIRCallback failed.";
+      return std::nullopt;
+    }
+  }
 
   // Optimize the module.
   if (failed(optimizeModule(*llvmModule, optLevel)))
     return std::nullopt;
 
-  if (optimizedLlvmIRCallback)
-    optimizedLlvmIRCallback(*llvmModule);
+  if (optimizedLlvmIRCallback) {
+    if (failed(optimizedLlvmIRCallback(*llvmModule))) {
+      getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
+      return std::nullopt;
+    }
+  }
 
   // Return the serialized object.
   return moduleToObject(*llvmModule);
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..cbd6a6d878813 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
     return std::nullopt;
   }
 
-  if (isaCallback)
-    isaCallback(serializedISA.value());
+  if (isaCallback) {
+    if (failed(isaCallback(serializedISA.value()))) {
+      getOperation().emitError() << "ISACallback failed.";
+      return std::nullopt;
+    }
+  }
 
 #define DEBUG_TYPE "serialize-to-isa"
   LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..1692c4490e4d1 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
   ASSERT_TRUE(!!serializer);
 
   std::string initialLLVMIR;
-  auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+  auto initialCallback =
+      [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(initialLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   std::string linkedLLVMIR;
-  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(linkedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   std::string optimizedLLVMIR;
-  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+  auto optimizedCallback =
+      [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(optimizedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   std::string isaResult;
-  auto isaCallback = [&isaResult](llvm::StringRef isa) {
+  auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
     isaResult = isa.str();
+    return success();
   };
 
   gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
   }
 }
 
+// Test callback functions failure with ISA.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+
+  auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
+                             {}, {}, {}, {}, isaCallback);
+
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+
+    ASSERT_TRUE(object == std::nullopt);
+  }
+}
+
 // Test linking LLVM IR from a resource attribute.
 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
   MLIRContext context(registry);
@@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
 
   // Hook to intercept the LLVM IR after linking external libs.
   std::string linkedLLVMIR;
-  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(linkedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   // Store the bitcode as a DenseI8ArrayAttr.
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 3c880edee4ffc..b392065132787 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
   auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
 
   std::string initialLLVMIR;
-  auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+  auto initialCallback =
+      [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(initialLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   gpu::TargetOptions opts(
@@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
   auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
 
   std::string linkedLLVMIR;
-  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(linkedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   gpu::TargetOptions opts(
@@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM,
   auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
 
   std::string optimizedLLVMIR;
-  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+  auto optimizedCallback =
+      [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(optimizedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   gpu::TargetOptions opts(
@@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM,
   ASSERT_TRUE(!serializedBinary->empty());
   ASSERT_TRUE(!optimizedLLVMIR.empty());
 }
+
+// Test callback function failure with initial LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+      {}, initialCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with linked LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+      {}, {}, linkedCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with optimized LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+      {}, {}, {}, optimizedCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary == std::nullopt);
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Initial discussion #170016 (comment)

While the initial PR is using these callbacks for debug printing, and filesystem failures are not directly related to this code logic, I can envision passes using these for IR validation and/or module pre/postprocessing which can legitimate fail.


Full diff: https://github.com/llvm/llvm-project/pull/170134.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h (+17-16)
  • (modified) mlir/include/mlir/Target/LLVM/ModuleToObject.h (+8-8)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+12-12)
  • (modified) mlir/lib/Target/LLVM/ModuleToObject.cpp (+23-10)
  • (modified) mlir/lib/Target/LLVM/NVVM/Target.cpp (+6-2)
  • (modified) mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp (+40-5)
  • (modified) mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp (+80-3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 139360f8bd3fc..00f885898ffa1 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -55,10 +55,10 @@ class TargetOptions {
       StringRef cmdOptions = {}, StringRef elfSection = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
       function_ref<SymbolTable *()> getSymbolTableCallback = {},
-      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<LogicalResult(StringRef)> isaCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -97,19 +97,20 @@ class TargetOptions {
 
   /// Returns the callback invoked with the initial LLVM IR for the device
   /// module.
-  function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+  function_ref<LogicalResult(llvm::Module &)> getInitialLlvmIRCallback() const;
 
   /// Returns the callback invoked with LLVM IR for the device module
   /// after linking the device libraries.
-  function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+  function_ref<LogicalResult(llvm::Module &)> getLinkedLlvmIRCallback() const;
 
   /// Returns the callback invoked with LLVM IR for the device module after
   /// LLVM optimizations but before codegen.
-  function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+  function_ref<LogicalResult(llvm::Module &)>
+  getOptimizedLlvmIRCallback() const;
 
   /// Returns the callback invoked with the target ISA for the device,
   /// for example PTX assembly.
-  function_ref<void(StringRef)> getISACallback() const;
+  function_ref<LogicalResult(StringRef)> getISACallback() const;
 
   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
   static CompilationTarget getDefaultCompilationTarget();
@@ -127,10 +128,10 @@ class TargetOptions {
       StringRef elfSection = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
       function_ref<SymbolTable *()> getSymbolTableCallback = {},
-      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<LogicalResult(StringRef)> isaCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -153,19 +154,19 @@ class TargetOptions {
   function_ref<SymbolTable *()> getSymbolTableCallback;
 
   /// Callback invoked with the initial LLVM IR for the device module.
-  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// linking the device libraries.
-  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// LLVM optimizations but before codegen.
-  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
 
   /// Callback invoked with the target ISA for the device,
   /// for example PTX assembly.
-  function_ref<void(StringRef)> isaCallback;
+  function_ref<LogicalResult(StringRef)> isaCallback;
 
 private:
   TypeID typeID;
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index 11fea6f0a4443..0edc20cd32620 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -32,10 +32,10 @@ class ModuleToObject {
   ModuleToObject(
       Operation &module, StringRef triple, StringRef chip,
       StringRef features = {}, int optLevel = 3,
-      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
-      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
-      function_ref<void(StringRef)> isaCallback = {});
+      function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<LogicalResult(StringRef)> isaCallback = {});
   virtual ~ModuleToObject();
 
   /// Returns the operation being serialized.
@@ -120,19 +120,19 @@ class ModuleToObject {
   int optLevel;
 
   /// Callback invoked with the initial LLVM IR for the device module.
-  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// linking the device libraries.
-  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback;
 
   /// Callback invoked with LLVM IR for the device module after
   /// LLVM optimizations but before codegen.
-  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+  function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback;
 
   /// Callback invoked with the target ISA for the device,
   /// for example PTX assembly.
-  function_ref<void(StringRef)> isaCallback;
+  function_ref<LogicalResult(StringRef)> isaCallback;
 
 private:
   /// The TargetMachine created for the given Triple, if available.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..a813608fdf209 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2652,10 +2652,10 @@ TargetOptions::TargetOptions(
     StringRef cmdOptions, StringRef elfSection,
     CompilationTarget compilationTarget,
     function_ref<SymbolTable *()> getSymbolTableCallback,
-    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
-    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
-    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<LogicalResult(StringRef)> isaCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                     cmdOptions, elfSection, compilationTarget,
                     getSymbolTableCallback, initialLlvmIRCallback,
@@ -2667,10 +2667,10 @@ TargetOptions::TargetOptions(
     StringRef cmdOptions, StringRef elfSection,
     CompilationTarget compilationTarget,
     function_ref<SymbolTable *()> getSymbolTableCallback,
-    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
-    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
-    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<LogicalResult(StringRef)> isaCallback)
     : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
       cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
       compilationTarget(compilationTarget),
@@ -2696,22 +2696,22 @@ SymbolTable *TargetOptions::getSymbolTable() const {
   return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
 }
 
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
 TargetOptions::getInitialLlvmIRCallback() const {
   return initialLlvmIRCallback;
 }
 
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
 TargetOptions::getLinkedLlvmIRCallback() const {
   return linkedLlvmIRCallback;
 }
 
-function_ref<void(llvm::Module &)>
+function_ref<LogicalResult(llvm::Module &)>
 TargetOptions::getOptimizedLlvmIRCallback() const {
   return optimizedLlvmIRCallback;
 }
 
-function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+function_ref<LogicalResult(StringRef)> TargetOptions::getISACallback() const {
   return isaCallback;
 }
 
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 4098ccc548dc1..d881dda69453b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -36,10 +36,11 @@ using namespace mlir::LLVM;
 
 ModuleToObject::ModuleToObject(
     Operation &module, StringRef triple, StringRef chip, StringRef features,
-    int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
-    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
-    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
-    function_ref<void(StringRef)> isaCallback)
+    int optLevel,
+    function_ref<LogicalResult(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<LogicalResult(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<LogicalResult(StringRef)> isaCallback)
     : module(module), triple(triple), chip(chip), features(features),
       optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
       linkedLlvmIRCallback(linkedLlvmIRCallback),
@@ -254,8 +255,12 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
   }
   setDataLayoutAndTriple(*llvmModule);
 
-  if (initialLlvmIRCallback)
-    initialLlvmIRCallback(*llvmModule);
+  if (initialLlvmIRCallback) {
+    if (failed(initialLlvmIRCallback(*llvmModule))) {
+      getOperation().emitError() << "InitialLLVMIRCallback failed.";
+      return std::nullopt;
+    }
+  }
 
   // Link bitcode files.
   handleModulePreLink(*llvmModule);
@@ -269,15 +274,23 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
     handleModulePostLink(*llvmModule);
   }
 
-  if (linkedLlvmIRCallback)
-    linkedLlvmIRCallback(*llvmModule);
+  if (linkedLlvmIRCallback) {
+    if (failed(linkedLlvmIRCallback(*llvmModule))) {
+      getOperation().emitError() << "LinkedLLVMIRCallback failed.";
+      return std::nullopt;
+    }
+  }
 
   // Optimize the module.
   if (failed(optimizeModule(*llvmModule, optLevel)))
     return std::nullopt;
 
-  if (optimizedLlvmIRCallback)
-    optimizedLlvmIRCallback(*llvmModule);
+  if (optimizedLlvmIRCallback) {
+    if (failed(optimizedLlvmIRCallback(*llvmModule))) {
+      getOperation().emitError() << "OptimizedLLVMIRCallback failed.";
+      return std::nullopt;
+    }
+  }
 
   // Return the serialized object.
   return moduleToObject(*llvmModule);
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 8760ea8588e2c..cbd6a6d878813 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -707,8 +707,12 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
     return std::nullopt;
   }
 
-  if (isaCallback)
-    isaCallback(serializedISA.value());
+  if (isaCallback) {
+    if (failed(isaCallback(serializedISA.value()))) {
+      getOperation().emitError() << "ISACallback failed.";
+      return std::nullopt;
+    }
+  }
 
 #define DEBUG_TYPE "serialize-to-isa"
   LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n"
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index af0af89c7d07e..1692c4490e4d1 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -176,26 +176,32 @@ TEST_F(MLIRTargetLLVMNVVM,
   ASSERT_TRUE(!!serializer);
 
   std::string initialLLVMIR;
-  auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+  auto initialCallback =
+      [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(initialLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   std::string linkedLLVMIR;
-  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(linkedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   std::string optimizedLLVMIR;
-  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+  auto optimizedCallback =
+      [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(optimizedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   std::string isaResult;
-  auto isaCallback = [&isaResult](llvm::StringRef isa) {
+  auto isaCallback = [&isaResult](llvm::StringRef isa) -> LogicalResult {
     isaResult = isa.str();
+    return success();
   };
 
   gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
@@ -220,6 +226,34 @@ TEST_F(MLIRTargetLLVMNVVM,
   }
 }
 
+// Test callback functions failure with ISA.
+TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+
+  NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+  auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+  ASSERT_TRUE(!!serializer);
+
+  auto isaCallback = [](llvm::StringRef /*isa*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly,
+                             {}, {}, {}, {}, isaCallback);
+
+  for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+    std::optional<SmallVector<char, 0>> object =
+        serializer.serializeToObject(gpuModule, options);
+
+    ASSERT_TRUE(object == std::nullopt);
+  }
+}
+
 // Test linking LLVM IR from a resource attribute.
 TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
   MLIRContext context(registry);
@@ -261,9 +295,10 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) {
 
   // Hook to intercept the LLVM IR after linking external libs.
   std::string linkedLLVMIR;
-  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(linkedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   // Store the bitcode as a DenseI8ArrayAttr.
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 3c880edee4ffc..b392065132787 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -168,9 +168,11 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
   auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
 
   std::string initialLLVMIR;
-  auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+  auto initialCallback =
+      [&initialLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(initialLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   gpu::TargetOptions opts(
@@ -196,9 +198,10 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
   auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
 
   std::string linkedLLVMIR;
-  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+  auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(linkedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   gpu::TargetOptions opts(
@@ -225,9 +228,11 @@ TEST_F(MLIRTargetLLVM,
   auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
 
   std::string optimizedLLVMIR;
-  auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+  auto optimizedCallback =
+      [&optimizedLLVMIR](llvm::Module &module) -> LogicalResult {
     llvm::raw_string_ostream ros(optimizedLLVMIR);
     module.print(ros, nullptr);
+    return success();
   };
 
   gpu::TargetOptions opts(
@@ -240,3 +245,75 @@ TEST_F(MLIRTargetLLVM,
   ASSERT_TRUE(!serializedBinary->empty());
   ASSERT_TRUE(!optimizedLLVMIR.empty());
 }
+
+// Test callback function failure with initial LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  auto initialCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+      {}, initialCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with linked LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  auto linkedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+      {}, {}, linkedCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary == std::nullopt);
+}
+
+// Test callback function failure with optimized LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) {
+  MLIRContext context(registry);
+
+  OwningOpRef<ModuleOp> module =
+      parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  Builder builder(&context);
+  IntegerAttr target = builder.getI32IntegerAttr(0);
+  auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+  auto optimizedCallback = [](llvm::Module & /*module*/) -> LogicalResult {
+    return failure();
+  };
+
+  gpu::TargetOptions opts(
+      {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(),
+      {}, {}, {}, optimizedCallback);
+  std::optional<SmallVector<char, 0>> serializedBinary =
+      targetAttr.serializeToObject(*module, opts);
+
+  ASSERT_TRUE(serializedBinary == std::nullopt);
+}

Copy link
Contributor

@fabianmcg fabianmcg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general LGTM, the only change I would make is:

  • Passing a Location as well to the callbacks
  • Let the callback handle the error message, so no getOperation().emitError() << "ISACallback failed.";

I'd argue that this is already a wanted change, as I could have a custom out of tree LLVM pass running at any of the LLVM callback levels that could fail.

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a diagnostic issue here: the callbacks don't have any information needed to diagnose (no MLIRContext, no operation to attach the diagnostic to...).

The use-case for this looks quite hand-wavy to me to justify adding more complexity to correctly handle all this though.

@Hardcode84
Copy link
Contributor Author

There is a diagnostic issue here: the callbacks don't have any information needed to diagnose (no MLIRContext, no operation to attach the diagnostic to...).

I can add source op arg to all the callbacks.

The use-case for this looks quite hand-wavy to me to justify adding more complexity to correctly handle all this though.

The one potential use case I considered is something similar to Triton TRITON_OVERRIDE_DIR https://github.com/triton-lang/triton/blob/main/README.md#tips-for-hacking, where you can override an IR from file on specific step for debugging. You need to properly propagate fs/parsing errors in this case obviously.

linkedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, llvm::Module &)>
optimizedLlvmIRCallback = {},
function_ref<LogicalResult(Operation *op, StringRef)> isaCallback = {});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this opens up the possibility for the callbacks to modify the MLIR IR, which is a pretty wide opening of the API, I don't think this is what this API is intended to support.

Copy link
Contributor Author

@Hardcode84 Hardcode84 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I can do it though the error callbacks like I did before https://github.com/llvm/llvm-project/blob/main/mlir/lib/Transforms/CompositePass.cpp#L45

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature is strange to me for a errorHandler. Look at the usual pattern in MLIR: https://github.com/search?q=repo%3Allvm%2Fllvm-project%20%22%3E%20emitError%22&type=code

 function_ref<InFlightDiagnostic()> emitError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to InFlightDiagnostic

@joker-eph
Copy link
Collaborator

joker-eph commented Dec 2, 2025

`` https://github.com/triton-lang/triton/blob/main/README.md#tips-for-hacking, where you can override an IR from file on specific step for debugging. You need to properly propagate fs/parsing errors in this case obviously.

This can already be done outside of ModuleToObject the way you've done it in the pass to signalPassFailure, no need to add the complexity here for this.
This is still a failure that is somehow unrelated to the logic internal to ModuleToObject and can be handled by the invoking code just fine.

@joker-eph
Copy link
Collaborator

In general if there is a need to fine grain control the logic across the steps that are encapsulated, then I rather avoid the complexity of the many handshakes with callbacks, and expose the individual steps to allow the users to stitch them however they want. This is a fairly usual approach to modularity in MLIR: a nested grouping of components where you can skip every high level of API to reach to a lower-level API to customize the assembly. The higher-level APIs are conveniently exposed but not the only entry points.
The usual example is how passes upstream can be "simple" wrapper around other "utilities" APIs, but downstream can customize by reimplementing the pass to glue different sets of options or logic around the same "utilities" to get their custom variant of the pass.

@Hardcode84
Copy link
Contributor Author

Hardcode84 commented Dec 3, 2025

If we want to override llvm IR during linking/optimizations phases and/or override device assembly, we will have to do it in those callbacks and we will have callbacks to return errors.

@joker-eph
Copy link
Collaborator

If we want to override llvm IR during linking/optimizations phases and/or override device assembly, we will have to do it in those callbacks and we will have callbacks to return errors.

I don't believe this is needed: proof of existence if what you're doing in the other PR to catch the error.

I'll reiterate that the callback design is creating a monolithic component, which is not modular. We went with this only because the callbacks were intended as "debugging" things that don't affect the behavior of ModuleToObject significantly.

If we need to provide more flexibility to users, the more modular way to achieve it is by disaggregating the component into pieces that the user can assemble, instead of making the "monolith" more complex with complex config/callbacks mechanism. Hence my objection to the current direction here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants