Skip to content

Conversation

linuxlonelyeagle
Copy link
Member

Make drop-equivalent-buffer-results handle return ops in multiple blocks within a function.

@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-bufferization

Author: lonely eagle (linuxlonelyeagle)

Changes

Make drop-equivalent-buffer-results handle return ops in multiple blocks within a function.


Full diff: https://github.com/llvm/llvm-project/pull/163388.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp (+46-22)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+30)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..9c300cc347ecf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,38 @@ namespace bufferization {
 
 using namespace mlir;
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
+/// Get all the ReturnOp in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> returnOps;
   for (Block &b : funcOp.getBody()) {
     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
+      returnOps.push_back(candidateOp);
     }
   }
-  return returnOp;
+  return returnOps;
+}
+
+/// Get the values at the same position in the `returnOps`.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+  SmallVector<Value> operands;
+  for (func::ReturnOp returnOp : returnOps) {
+    operands.push_back(returnOp.getOperand(pos));
+  }
+  return operands;
+}
+
+/// Check if the value in operands is equal to the argument.
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+                                      BlockArgument argument) {
+  for (Value val : operands) {
+    while (auto castOp = val.getDefiningOp<memref::CastOp>())
+      val = castOp.getSource();
+
+    if (val != argument)
+      return false;
+  }
+  return true;
 }
 
 LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   for (auto funcOp : module.getOps<func::FuncOp>()) {
     if (funcOp.isExternal() || funcOp.isPublic())
       continue;
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    // TODO: Support functions with multiple blocks.
-    if (!returnOp)
+    SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+    if (returnOps.empty())
       continue;
+    func::ReturnOp returnOp = returnOps.front();
 
     // Compute erased results.
-    SmallVector<Value> newReturnValues;
+    SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
     BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
     DenseMap<int64_t, int64_t> resultToArgs;
-    for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+    for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
       bool erased = false;
+      SmallVector<Value> returnOperands =
+          getReturnOpsOperandInPos(returnOps, i);
       for (BlockArgument bbArg : funcOp.getArguments()) {
-        Value val = it.value();
-        while (auto castOp = val.getDefiningOp<memref::CastOp>())
-          val = castOp.getSource();
-
-        if (val == bbArg) {
-          resultToArgs[it.index()] = bbArg.getArgNumber();
+        if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+          resultToArgs[i] = bbArg.getArgNumber();
           erased = true;
           break;
         }
       }
 
       if (erased) {
-        erasedResultIndices.set(it.index());
+        erasedResultIndices.set(i);
       } else {
-        newReturnValues.push_back(it.value());
+        for (auto [newReturnValue, operand] :
+             llvm::zip(newReturnValues, returnOperands)) {
+          newReturnValue.push_back(operand);
+        }
       }
     }
 
     // Update function.
     if (failed(funcOp.eraseResults(erasedResultIndices)))
       return failure();
-    returnOp.getOperandsMutable().assign(newReturnValues);
+
+    for (auto [returnOp, newReturnValue] :
+         llvm::zip(returnOps, newReturnValues))
+      returnOp.getOperandsMutable().assign(newReturnValue);
 
     // Update function calls.
     for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..641dcf0d59990 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,33 @@ func.func @collapse_shape_regression(
   tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
   return
 }
+
+// -----
+
+
+// CHECK-LABEL: func private @mult_return_callee(
+//  CHECK-SAME:   %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+//  CHECK-SAME:   %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+func.func private @mult_return_callee(%t: tensor<?xf32>,  %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  // CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+  // CHECK:  ^bb1:
+  // CHECK:    return %[[A]] : index
+  // CHECK:  ^bb2:
+  // CHECK:    return %[[B]] : index
+  cf.cond_br %cond,^a, ^b
+  ^a:
+    return %casted, %a : tensor<10xf32>, index
+  ^b:
+    return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+//  CHECK-SAME:   %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+//  CHECK-SAME:   %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>,  %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+  // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+  %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index) 
+  return %t_res, %v : tensor<10xf32>, index
+}

@llvmbot
Copy link
Member

llvmbot commented Oct 14, 2025

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

Changes

Make drop-equivalent-buffer-results handle return ops in multiple blocks within a function.


Full diff: https://github.com/llvm/llvm-project/pull/163388.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp (+46-22)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+30)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..9c300cc347ecf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,38 @@ namespace bufferization {
 
 using namespace mlir;
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
+/// Get all the ReturnOp in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> returnOps;
   for (Block &b : funcOp.getBody()) {
     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
+      returnOps.push_back(candidateOp);
     }
   }
-  return returnOp;
+  return returnOps;
+}
+
+/// Get the values at the same position in the `returnOps`.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+  SmallVector<Value> operands;
+  for (func::ReturnOp returnOp : returnOps) {
+    operands.push_back(returnOp.getOperand(pos));
+  }
+  return operands;
+}
+
+/// Check if the value in operands is equal to the argument.
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+                                      BlockArgument argument) {
+  for (Value val : operands) {
+    while (auto castOp = val.getDefiningOp<memref::CastOp>())
+      val = castOp.getSource();
+
+    if (val != argument)
+      return false;
+  }
+  return true;
 }
 
 LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   for (auto funcOp : module.getOps<func::FuncOp>()) {
     if (funcOp.isExternal() || funcOp.isPublic())
       continue;
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    // TODO: Support functions with multiple blocks.
-    if (!returnOp)
+    SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+    if (returnOps.empty())
       continue;
+    func::ReturnOp returnOp = returnOps.front();
 
     // Compute erased results.
-    SmallVector<Value> newReturnValues;
+    SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
     BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
     DenseMap<int64_t, int64_t> resultToArgs;
-    for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+    for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
       bool erased = false;
+      SmallVector<Value> returnOperands =
+          getReturnOpsOperandInPos(returnOps, i);
       for (BlockArgument bbArg : funcOp.getArguments()) {
-        Value val = it.value();
-        while (auto castOp = val.getDefiningOp<memref::CastOp>())
-          val = castOp.getSource();
-
-        if (val == bbArg) {
-          resultToArgs[it.index()] = bbArg.getArgNumber();
+        if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+          resultToArgs[i] = bbArg.getArgNumber();
           erased = true;
           break;
         }
       }
 
       if (erased) {
-        erasedResultIndices.set(it.index());
+        erasedResultIndices.set(i);
       } else {
-        newReturnValues.push_back(it.value());
+        for (auto [newReturnValue, operand] :
+             llvm::zip(newReturnValues, returnOperands)) {
+          newReturnValue.push_back(operand);
+        }
       }
     }
 
     // Update function.
     if (failed(funcOp.eraseResults(erasedResultIndices)))
       return failure();
-    returnOp.getOperandsMutable().assign(newReturnValues);
+
+    for (auto [returnOp, newReturnValue] :
+         llvm::zip(returnOps, newReturnValues))
+      returnOp.getOperandsMutable().assign(newReturnValue);
 
     // Update function calls.
     for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..641dcf0d59990 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,33 @@ func.func @collapse_shape_regression(
   tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
   return
 }
+
+// -----
+
+
+// CHECK-LABEL: func private @mult_return_callee(
+//  CHECK-SAME:   %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+//  CHECK-SAME:   %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+func.func private @mult_return_callee(%t: tensor<?xf32>,  %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  // CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+  // CHECK:  ^bb1:
+  // CHECK:    return %[[A]] : index
+  // CHECK:  ^bb2:
+  // CHECK:    return %[[B]] : index
+  cf.cond_br %cond,^a, ^b
+  ^a:
+    return %casted, %a : tensor<10xf32>, index
+  ^b:
+    return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+//  CHECK-SAME:   %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+//  CHECK-SAME:   %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>,  %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+  // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+  %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index) 
+  return %t_res, %v : tensor<10xf32>, index
+}

@linuxlonelyeagle linuxlonelyeagle force-pushed the drop-equivalent-buffer-results-mult-block branch from 191d6f0 to 5c4c6de Compare October 14, 2025 16:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:bufferization Bufferization infrastructure mlir:tensor mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants