From 881215384dd1f79c5e05c85579b1596f2a769078 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 29 Jan 2025 11:48:35 -0500 Subject: [PATCH] Generalize returns to be ops with ReturnLike trait --- .../FuncBufferizableOpInterfaceImpl.h | 2 +- .../FuncBufferizableOpInterfaceImpl.cpp | 20 +++++++++++-------- .../Transforms/OneShotModuleBufferize.cpp | 12 +++++------ 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h index e8e6226460ac7..caf157b87be87 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h @@ -23,7 +23,7 @@ class FuncOp; namespace bufferization { /// Helper function that returns all func.return ops in the given function. -SmallVector getReturnOps(func::FuncOp funcOp); +SmallVector getReturnOps(func::FuncOp funcOp); namespace func_ext { /// The state of analysis of a FuncOp. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index c45678f1e4b4d..df2fe08d02c90 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -20,11 +20,13 @@ namespace mlir { /// Return all func.return ops in the given function. -SmallVector bufferization::getReturnOps(func::FuncOp funcOp) { - SmallVector result; - for (Block &b : funcOp.getBody()) - if (auto returnOp = dyn_cast(b.getTerminator())) - result.push_back(returnOp); +SmallVector bufferization::getReturnOps(func::FuncOp funcOp) { + SmallVector result; + for (Block &b : funcOp.getBody()) { + Operation *terminator = b.getTerminator(); + if (terminator->hasTrait()) + result.push_back(b.getTerminator()); + } return result; } @@ -439,7 +441,7 @@ struct FuncOpInterface return failure(); // 2. Bufferize the operands of the all return op. - for (func::ReturnOp returnOp : getReturnOps(funcOp)) { + for (Operation *returnOp : getReturnOps(funcOp)) { assert(returnOp->getNumOperands() == retTypes.size() && "incorrect number of return values"); SmallVector returnValues; @@ -457,11 +459,13 @@ struct FuncOpInterface // Note: If `inferFunctionResultLayout = true`, casts are later folded // away. Value toMemrefOp = rewriter.create( - returnOp.getLoc(), bufferizedType, returnVal); + returnOp->getLoc(), bufferizedType, returnVal); returnValues.push_back(toMemrefOp); } - returnOp.getOperandsMutable().assign(returnValues); + for (auto [i, operand] : enumerate(returnValues)) { + returnOp->setOperand(i, operand); + } } // 3. Set the new function type. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 71ea0fd9d43cd..0ba4a1ecf7992 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, } // Find all func.return ops. - SmallVector returnOps = getReturnOps(funcOp); + SmallVector returnOps = getReturnOps(funcOp); assert(!returnOps.empty() && "expected at least one ReturnOp"); // Build alias sets. Merge all aliases from all func.return ops. @@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, int64_t bbArgIdx = bbArg.getArgNumber(); // Store aliases in a set, so that we don't add the same alias twice. SetVector aliases; - for (func::ReturnOp returnOp : returnOps) { + for (Operation *returnOp : returnOps) { for (OpOperand &returnVal : returnOp->getOpOperands()) { if (isa(returnVal.get().getType())) { int64_t returnIdx = returnVal.getOperandNumber(); @@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, // argument for the i-th operand. In contrast to aliasing information, // which is just "merged", equivalence information must match across all // func.return ops. - for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) { + for (Operation *returnOp : ArrayRef(returnOps).drop_front()) { std::optional maybeEquiv = findEquivalentBlockArgIdx(returnOp->getOpOperand(i)); if (maybeEquiv != bbArgIdx) { @@ -398,7 +398,7 @@ static Value unpackCast(Value v) { /// func.return ops. This function returns as many types as the return ops have /// operands. If the i-th operand is not the same for all func.return ops, then /// the i-th returned type is an "empty" type. -static SmallVector getReturnTypes(SmallVector returnOps) { +static SmallVector getReturnTypes(SmallVector returnOps) { assert(!returnOps.empty() && "expected at least one ReturnOp"); int numOperands = returnOps.front()->getNumOperands(); @@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) { return; // Compute the common result types of all return ops. - SmallVector returnOps = getReturnOps(funcOp); + SmallVector returnOps = getReturnOps(funcOp); SmallVector resultTypes = getReturnTypes(returnOps); // Remove direct casts. - for (func::ReturnOp returnOp : returnOps) { + for (Operation *returnOp : returnOps) { for (OpOperand &operand : returnOp->getOpOperands()) { // Bail if no common result type was found. if (resultTypes[operand.getOperandNumber()]) {