@@ -41,18 +41,37 @@ namespace bufferization {
4141
4242using namespace mlir ;
4343
44- // / Return the unique ReturnOp that terminates `funcOp`.
45- // / Return nullptr if there is no such unique ReturnOp.
46- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
47- func::ReturnOp returnOp;
44+ // / Get all the ReturnOp in the funcOp.
45+ static SmallVector<func::ReturnOp> getReturnOps (func::FuncOp funcOp) {
46+ SmallVector<func::ReturnOp> returnOps;
4847 for (Block &b : funcOp.getBody ()) {
4948 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
50- if (returnOp)
51- return nullptr ;
52- returnOp = candidateOp;
49+ returnOps.push_back (candidateOp);
5350 }
5451 }
55- return returnOp;
52+ return returnOps;
53+ }
54+
55+ // / Get the operands at the specified position for all returnOps.
56+ static SmallVector<Value>
57+ getReturnOpsOperandInPos (ArrayRef<func::ReturnOp> returnOps, size_t pos) {
58+ return llvm::map_to_vector (returnOps, [&](func::ReturnOp returnOp) {
59+ return returnOp.getOperand (pos);
60+ });
61+ }
62+
63+ // / Check if all given values are the same buffer as the block argument (modulo
64+ // / cast ops).
65+ static bool operandsEqualFuncArgument (ArrayRef<Value> operands,
66+ BlockArgument argument) {
67+ for (Value val : operands) {
68+ while (auto castOp = val.getDefiningOp <memref::CastOp>())
69+ val = castOp.getSource ();
70+
71+ if (val != argument)
72+ return false ;
73+ }
74+ return true ;
5675}
5776
5877LogicalResult
@@ -72,40 +91,45 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
7291 for (auto funcOp : module .getOps <func::FuncOp>()) {
7392 if (funcOp.isExternal () || funcOp.isPublic ())
7493 continue ;
75- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
76- // TODO: Support functions with multiple blocks.
77- if (!returnOp)
94+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
95+ if (returnOps.empty ())
7896 continue ;
7997
8098 // Compute erased results.
81- SmallVector<Value> newReturnValues;
82- BitVector erasedResultIndices (funcOp.getFunctionType ().getNumResults ());
99+ size_t numReturnOps = returnOps.size ();
100+ size_t numReturnValues = funcOp.getFunctionType ().getNumResults ();
101+ SmallVector<SmallVector<Value>> newReturnValues (numReturnOps);
102+ BitVector erasedResultIndices (numReturnValues);
83103 DenseMap<int64_t , int64_t > resultToArgs;
84- for (const auto &it : llvm::enumerate (returnOp. getOperands ()) ) {
104+ for (size_t i = 0 ; i < numReturnValues; ++i ) {
85105 bool erased = false ;
106+ SmallVector<Value> returnOperands =
107+ getReturnOpsOperandInPos (returnOps, i);
86108 for (BlockArgument bbArg : funcOp.getArguments ()) {
87- Value val = it.value ();
88- while (auto castOp = val.getDefiningOp <memref::CastOp>())
89- val = castOp.getSource ();
90-
91- if (val == bbArg) {
92- resultToArgs[it.index ()] = bbArg.getArgNumber ();
109+ if (operandsEqualFuncArgument (returnOperands, bbArg)) {
110+ resultToArgs[i] = bbArg.getArgNumber ();
93111 erased = true ;
94112 break ;
95113 }
96114 }
97115
98116 if (erased) {
99- erasedResultIndices.set (it. index () );
117+ erasedResultIndices.set (i );
100118 } else {
101- newReturnValues.push_back (it.value ());
119+ for (auto [newReturnValue, operand] :
120+ llvm::zip (newReturnValues, returnOperands)) {
121+ newReturnValue.push_back (operand);
122+ }
102123 }
103124 }
104125
105126 // Update function.
106127 if (failed (funcOp.eraseResults (erasedResultIndices)))
107128 return failure ();
108- returnOp.getOperandsMutable ().assign (newReturnValues);
129+
130+ for (auto [returnOp, newReturnValue] :
131+ llvm::zip (returnOps, newReturnValues))
132+ returnOp.getOperandsMutable ().assign (newReturnValue);
109133
110134 // Update function calls.
111135 for (func::CallOp callOp : callerMap[funcOp]) {
0 commit comments