
Commit 50d65a5

[mlir][bufferize] Make drop-equivalent-buffer-results support multiple blocks (#163388)
Enable drop-equivalent-buffer-results to handle return ops in multiple blocks within a function.
1 parent 23341c3 commit 50d65a5
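
For context: drop-equivalent-buffer-results erases function results that are equivalent to one of the function's buffer arguments and rewrites callers to reuse their own operand instead. A minimal sketch of the single-block case (illustrative IR, not taken from this commit; the function name is hypothetical):

// Before the pass: the private callee returns its own argument.
func.func private @callee(%arg0: memref<10xf32>) -> memref<10xf32> {
  return %arg0 : memref<10xf32>
}

// After the pass: the equivalent result is erased; callers substitute
// their own operand for the former result.
func.func private @callee(%arg0: memref<10xf32>) {
  return
}

Previously the pass bailed out on any function with more than one block, since it assumed a single unique ReturnOp; this commit lifts that restriction.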

File tree

2 files changed: +76 −23 lines

mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp

Lines changed: 47 additions & 23 deletions
@@ -41,18 +41,37 @@ namespace bufferization {
 
 using namespace mlir;
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
+/// Get all the ReturnOp in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> returnOps;
   for (Block &b : funcOp.getBody()) {
     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
+      returnOps.push_back(candidateOp);
     }
   }
-  return returnOp;
+  return returnOps;
+}
+
+/// Get the operands at the specified position for all returnOps.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+  return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
+    return returnOp.getOperand(pos);
+  });
+}
+
+/// Check if all given values are the same buffer as the block argument (modulo
+/// cast ops).
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+                                      BlockArgument argument) {
+  for (Value val : operands) {
+    while (auto castOp = val.getDefiningOp<memref::CastOp>())
+      val = castOp.getSource();
+
+    if (val != argument)
+      return false;
+  }
+  return true;
 }
 
 LogicalResult
@@ -72,40 +91,45 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   for (auto funcOp : module.getOps<func::FuncOp>()) {
     if (funcOp.isExternal() || funcOp.isPublic())
       continue;
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    // TODO: Support functions with multiple blocks.
-    if (!returnOp)
+    SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+    if (returnOps.empty())
       continue;
 
     // Compute erased results.
-    SmallVector<Value> newReturnValues;
-    BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
+    size_t numReturnOps = returnOps.size();
+    size_t numReturnValues = funcOp.getFunctionType().getNumResults();
+    SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
+    BitVector erasedResultIndices(numReturnValues);
     DenseMap<int64_t, int64_t> resultToArgs;
-    for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+    for (size_t i = 0; i < numReturnValues; ++i) {
       bool erased = false;
+      SmallVector<Value> returnOperands =
+          getReturnOpsOperandInPos(returnOps, i);
       for (BlockArgument bbArg : funcOp.getArguments()) {
-        Value val = it.value();
-        while (auto castOp = val.getDefiningOp<memref::CastOp>())
-          val = castOp.getSource();
-
-        if (val == bbArg) {
-          resultToArgs[it.index()] = bbArg.getArgNumber();
+        if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+          resultToArgs[i] = bbArg.getArgNumber();
           erased = true;
           break;
         }
       }
 
       if (erased) {
-        erasedResultIndices.set(it.index());
+        erasedResultIndices.set(i);
       } else {
-        newReturnValues.push_back(it.value());
+        for (auto [newReturnValue, operand] :
+             llvm::zip(newReturnValues, returnOperands)) {
+          newReturnValue.push_back(operand);
+        }
      }
    }
 
     // Update function.
     if (failed(funcOp.eraseResults(erasedResultIndices)))
       return failure();
-    returnOp.getOperandsMutable().assign(newReturnValues);
+
+    for (auto [returnOp, newReturnValue] :
+         llvm::zip(returnOps, newReturnValues))
+      returnOp.getOperandsMutable().assign(newReturnValue);
 
     // Update function calls.
     for (func::CallOp callOp : callerMap[funcOp]) {
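
With the new helpers, result position i is erased only when every ReturnOp's operand at that position traces back, through any chain of memref.cast ops, to the same function argument. Two illustrative cases (hypothetical IR, not part of this commit):

// Dropped: the returned value is just a cast of the argument;
// operandsEqualFuncArgument looks through memref.cast.
func.func private @cast_of_arg(%arg0: memref<?xf32>) -> memref<10xf32> {
  %0 = memref.cast %arg0 : memref<?xf32> to memref<10xf32>
  return %0 : memref<10xf32>
}

// Kept: the two returns yield different buffers at position 0, so no
// single argument is equivalent on all paths.
func.func private @keep_result(%m0: memref<10xf32>, %m1: memref<10xf32>, %cond: i1) -> memref<10xf32> {
  cf.cond_br %cond, ^a, ^b
^a:
  return %m0 : memref<10xf32>
^b:
  return %m1 : memref<10xf32>
}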

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 29 additions & 0 deletions
@@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
   tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[A]] : index
+// CHECK: ^bb2:
+// CHECK: return %[[B]] : index
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond: i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  cf.cond_br %cond, ^a, ^b
+^a:
+  return %casted, %a : tensor<10xf32>, index
+^b:
+  return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond: i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+  // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+  %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+  return %t_res, %v : tensor<10xf32>, index
+}
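
Reconstructed from the CHECK lines above, the bufferized callee after drop-equivalent-buffer-results should look roughly as follows (a sketch; exact memref layouts depend on the one-shot-bufferize options in the test's RUN line):

// Sketch of the expected post-pass callee, per the CHECK lines.
func.func private @mult_return_callee(%t: memref<?xf32, strided<[?], offset: ?>>, %cond: i1, %a: index, %b: index) -> index {
  cf.cond_br %cond, ^bb1, ^bb2
^bb1:
  return %a : index
^bb2:
  return %b : index
}

Both return sites yield the casted alias of %t at result position 0, so that result is dropped from the signature; the index result differs per path (%a vs. %b) and is kept.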
