diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 70faa71a5ffbb..9c300cc347ecf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,18 +41,38 @@ namespace bufferization {
 
 using namespace mlir;
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
+/// Collect every func::ReturnOp that terminates a block of `funcOp`.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> returnOps;
   for (Block &b : funcOp.getBody()) {
     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
+      returnOps.push_back(candidateOp);
     }
   }
-  return returnOp;
+  return returnOps;
+}
+
+/// Gather operand number `pos` of each op in `returnOps`, in the same order.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+  SmallVector<Value> operands;
+  for (func::ReturnOp returnOp : returnOps) {
+    operands.push_back(returnOp.getOperand(pos));
+  }
+  return operands;
+}
+
+/// Return true iff every operand, looking through casts, is `argument`.
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+                                      BlockArgument argument) {
+  for (Value val : operands) {
+    while (auto castOp = val.getDefiningOp<memref::CastOp>())
+      val = castOp.getSource();
+
+    if (val != argument)
+      return false;
+  }
+  return true;
 }
 
 LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   for (auto funcOp : module.getOps<func::FuncOp>()) {
     if (funcOp.isExternal() || funcOp.isPublic())
       continue;
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    // TODO: Support functions with multiple blocks.
-    if (!returnOp)
+    SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+    if (returnOps.empty())
       continue;
+    func::ReturnOp returnOp = returnOps.front();
 
     // Compute erased results.
-    SmallVector<Value> newReturnValues;
+    SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
     BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
     DenseMap<int64_t, int64_t> resultToArgs;
-    for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+    for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
       bool erased = false;
+      SmallVector<Value> returnOperands =
+          getReturnOpsOperandInPos(returnOps, i);
       for (BlockArgument bbArg : funcOp.getArguments()) {
-        Value val = it.value();
-        while (auto castOp = val.getDefiningOp<memref::CastOp>())
-          val = castOp.getSource();
-
-        if (val == bbArg) {
-          resultToArgs[it.index()] = bbArg.getArgNumber();
+        if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+          resultToArgs[i] = bbArg.getArgNumber();
           erased = true;
           break;
         }
       }
 
       if (erased) {
-        erasedResultIndices.set(it.index());
+        erasedResultIndices.set(i);
       } else {
-        newReturnValues.push_back(it.value());
+        for (auto [newReturnValue, operand] :
+             llvm::zip(newReturnValues, returnOperands)) {
+          newReturnValue.push_back(operand);
+        }
       }
     }
 
     // Update function.
     if (failed(funcOp.eraseResults(erasedResultIndices)))
       return failure();
-    returnOp.getOperandsMutable().assign(newReturnValues);
+
+    for (auto [returnOp, newReturnValue] :
+         llvm::zip(returnOps, newReturnValues))
+      returnOp.getOperandsMutable().assign(newReturnValue);
 
     // Update function calls.
     for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index b6c72bedef6c5..508e29303d37b 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -490,3 +490,32 @@ func.func @collapse_shape_regression(
   tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func private @mult_return_callee(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
+func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  // CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
+  // CHECK: ^bb1:
+  // CHECK: return %[[A]] : index
+  // CHECK: ^bb2:
+  // CHECK: return %[[B]] : index
+  cf.cond_br %cond,^a, ^b
+  ^a:
+    return %casted, %a : tensor<10xf32>, index
+  ^b:
+    return %casted, %b : tensor<10xf32>, index
+}
+
+// CHECK-LABEL: func @mult_return(
+// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
+func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
+  // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
+  // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
+  %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
+  return %t_res, %v : tensor<10xf32>, index
+}