Skip to content

Commit 8812153

Browse files
committed
Generalize returns to be ops with ReturnLike trait
1 parent e0054e9 commit 8812153

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class FuncOp;
2323

2424
namespace bufferization {
2525
/// Helper function that returns all func.return ops in the given function.
26-
SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
26+
SmallVector<Operation *> getReturnOps(func::FuncOp funcOp);
2727

2828
namespace func_ext {
2929
/// The state of analysis of a FuncOp.

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020

2121
namespace mlir {
2222
/// Return all func.return ops in the given function.
23-
SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
24-
SmallVector<func::ReturnOp> result;
25-
for (Block &b : funcOp.getBody())
26-
if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
27-
result.push_back(returnOp);
23+
SmallVector<Operation *> bufferization::getReturnOps(func::FuncOp funcOp) {
24+
SmallVector<Operation *> result;
25+
for (Block &b : funcOp.getBody()) {
26+
Operation *terminator = b.getTerminator();
27+
if (terminator->hasTrait<OpTrait::ReturnLike>())
28+
result.push_back(b.getTerminator());
29+
}
2830
return result;
2931
}
3032

@@ -439,7 +441,7 @@ struct FuncOpInterface
439441
return failure();
440442

441443
// 2. Bufferize the operands of the all return op.
442-
for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
444+
for (Operation *returnOp : getReturnOps(funcOp)) {
443445
assert(returnOp->getNumOperands() == retTypes.size() &&
444446
"incorrect number of return values");
445447
SmallVector<Value> returnValues;
@@ -457,11 +459,13 @@ struct FuncOpInterface
457459
// Note: If `inferFunctionResultLayout = true`, casts are later folded
458460
// away.
459461
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
460-
returnOp.getLoc(), bufferizedType, returnVal);
462+
returnOp->getLoc(), bufferizedType, returnVal);
461463
returnValues.push_back(toMemrefOp);
462464
}
463465

464-
returnOp.getOperandsMutable().assign(returnValues);
466+
for (auto [i, operand] : enumerate(returnValues)) {
467+
returnOp->setOperand(i, operand);
468+
}
465469
}
466470

467471
// 3. Set the new function type.

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
133133
}
134134

135135
// Find all func.return ops.
136-
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
136+
SmallVector<Operation *> returnOps = getReturnOps(funcOp);
137137
assert(!returnOps.empty() && "expected at least one ReturnOp");
138138

139139
// Build alias sets. Merge all aliases from all func.return ops.
@@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
142142
int64_t bbArgIdx = bbArg.getArgNumber();
143143
// Store aliases in a set, so that we don't add the same alias twice.
144144
SetVector<int64_t> aliases;
145-
for (func::ReturnOp returnOp : returnOps) {
145+
for (Operation *returnOp : returnOps) {
146146
for (OpOperand &returnVal : returnOp->getOpOperands()) {
147147
if (isa<RankedTensorType>(returnVal.get().getType())) {
148148
int64_t returnIdx = returnVal.getOperandNumber();
@@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
192192
// argument for the i-th operand. In contrast to aliasing information,
193193
// which is just "merged", equivalence information must match across all
194194
// func.return ops.
195-
for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
195+
for (Operation *returnOp : ArrayRef(returnOps).drop_front()) {
196196
std::optional<int64_t> maybeEquiv =
197197
findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
198198
if (maybeEquiv != bbArgIdx) {
@@ -398,7 +398,7 @@ static Value unpackCast(Value v) {
398398
/// func.return ops. This function returns as many types as the return ops have
399399
/// operands. If the i-th operand is not the same for all func.return ops, then
400400
/// the i-th returned type is an "empty" type.
401-
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
401+
static SmallVector<Type> getReturnTypes(SmallVector<Operation *> returnOps) {
402402
assert(!returnOps.empty() && "expected at least one ReturnOp");
403403
int numOperands = returnOps.front()->getNumOperands();
404404

@@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
434434
return;
435435

436436
// Compute the common result types of all return ops.
437-
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
437+
SmallVector<Operation *> returnOps = getReturnOps(funcOp);
438438
SmallVector<Type> resultTypes = getReturnTypes(returnOps);
439439

440440
// Remove direct casts.
441-
for (func::ReturnOp returnOp : returnOps) {
441+
for (Operation *returnOp : returnOps) {
442442
for (OpOperand &operand : returnOp->getOpOperands()) {
443443
// Bail if no common result type was found.
444444
if (resultTypes[operand.getOperandNumber()]) {

0 commit comments

Comments
 (0)