Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,38 @@ namespace bufferization {

using namespace mlir;

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
func::ReturnOp returnOp;
/// Get all the ReturnOp in the funcOp.
static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
returnOps.push_back(candidateOp);
}
}
return returnOp;
return returnOps;
}

/// Collect, for each ReturnOp in `returnOps`, the operand at index `pos`.
/// The result is ordered the same way as `returnOps`.
static SmallVector<Value>
getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
  SmallVector<Value> result;
  result.reserve(returnOps.size());
  for (func::ReturnOp op : returnOps)
    result.push_back(op.getOperand(pos));
  return result;
}

/// Return true if every value in `operands`, after looking through any
/// chain of memref.cast ops, is the given block argument `argument`.
static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
                                      BlockArgument argument) {
  // Strip a (possibly empty) chain of memref.cast ops to reach the
  // underlying value.
  auto lookThroughCasts = [](Value v) {
    while (auto castOp = v.getDefiningOp<memref::CastOp>())
      v = castOp.getSource();
    return v;
  };

  for (Value operand : operands)
    if (lookThroughCasts(operand) != argument)
      return false;
  return true;
}

LogicalResult
Expand All @@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal() || funcOp.isPublic())
continue;
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
// TODO: Support functions with multiple blocks.
if (!returnOp)
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
if (returnOps.empty())
continue;
func::ReturnOp returnOp = returnOps.front();

// Compute erased results.
SmallVector<Value> newReturnValues;
SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
DenseMap<int64_t, int64_t> resultToArgs;
for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
bool erased = false;
SmallVector<Value> returnOperands =
getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
Value val = it.value();
while (auto castOp = val.getDefiningOp<memref::CastOp>())
val = castOp.getSource();

if (val == bbArg) {
resultToArgs[it.index()] = bbArg.getArgNumber();
if (operandsEqualFuncArgument(returnOperands, bbArg)) {
resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}

if (erased) {
erasedResultIndices.set(it.index());
erasedResultIndices.set(i);
} else {
newReturnValues.push_back(it.value());
for (auto [newReturnValue, operand] :
llvm::zip(newReturnValues, returnOperands)) {
newReturnValue.push_back(operand);
}
}
}

// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
returnOp.getOperandsMutable().assign(newReturnValues);

for (auto [returnOp, newReturnValue] :
llvm::zip(returnOps, newReturnValues))
returnOp.getOperandsMutable().assign(newReturnValue);

// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
return
}

// -----

// Private callee with two blocks, each ending in its own return. Both
// returns yield `%casted`, which (looking through the tensor.cast) is the
// function argument %t — so after bufferization the equivalent buffer
// result is dropped from the signature and from BOTH return ops.
// NOTE(review): the CHECK/CHECK-SAME/CHECK-LABEL lines below are FileCheck
// directives, not ordinary comments — do not edit them casually.
// CHECK-LABEL: func private @mult_return_callee(
// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
%casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: return %[[A]] : index
// CHECK: ^bb2:
// CHECK: return %[[B]] : index
cf.cond_br %cond,^a, ^b
^a:
return %casted, %a : tensor<10xf32>, index
^b:
return %casted, %b : tensor<10xf32>, index
}

// Caller of @mult_return_callee: after the callee's equivalent tensor
// result is dropped, the call yields only the index result and the caller
// forwards its own %t buffer in place of the erased result.
// NOTE(review): CHECK lines below are FileCheck directives — load-bearing.
// CHECK-LABEL: func @mult_return(
// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
// CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
// CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
%t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
return %t_res, %v : tensor<10xf32>, index
}