@@ -88,6 +88,7 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
88
88
89
89
// / Return the unique ReturnOp that terminates `funcOp`.
90
90
// / Return nullptr if there is no such unique ReturnOp.
91
+ // / Return `funcOp` it self if there is no ReturnOp.
91
92
static Operation* getAssumedUniqueReturnOp (FunctionOpInterface funcOp) {
92
93
Operation *returnOp = nullptr ;
93
94
for (Block &b : funcOp.getFunctionBody ()) {
@@ -98,6 +99,8 @@ static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
98
99
returnOp = candidateOp;
99
100
}
100
101
}
102
+ if (!returnOp)
103
+ return funcOp;
101
104
return returnOp;
102
105
}
103
106
@@ -147,9 +150,10 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &s
147
150
}
148
151
149
152
// Support only single return-terminated block in the function.
150
- if (!isa<func::FuncOp>(funcOp))
151
- return success ();
153
+ // If funcOp has no returnOp, skip the following analysis.
152
154
Operation *returnOp = getAssumedUniqueReturnOp (funcOp);
155
+ if (returnOp == funcOp)
156
+ return success ();
153
157
assert (returnOp && " expected func with single return op" );
154
158
155
159
for (OpOperand &returnVal : returnOp->getOpOperands ())
@@ -300,9 +304,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
300
304
// For each FuncOp, the number of func::CallOp it contains.
301
305
DenseMap<FunctionOpInterface, unsigned > numberCallOpsContainedInFuncOp;
302
306
WalkResult res = moduleOp.walk ([&](FunctionOpInterface funcOp) -> WalkResult {
303
- if (!funcOp.getFunctionBody ().empty () && isa<func::FuncOp>(funcOp) ) {
307
+ if (!funcOp.getFunctionBody ().empty ()) {
304
308
Operation *returnOp = getAssumedUniqueReturnOp (funcOp);
305
- if (!returnOp)
309
+ if (!returnOp && returnOp != funcOp )
306
310
return funcOp->emitError ()
307
311
<< " cannot bufferize a FuncOp with tensors and "
308
312
" without a unique ReturnOp" ;
@@ -356,7 +360,7 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) {
356
360
357
361
Operation *returnOp = getAssumedUniqueReturnOp (funcOp);
358
362
359
- if (!returnOp)
363
+ if (!returnOp || returnOp == funcOp )
360
364
return ;
361
365
362
366
SmallVector<Type> resultTypes;
0 commit comments