@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8888
8989// / Return the unique ReturnOp that terminates `funcOp`.
9090// / Return nullptr if there is no such unique ReturnOp.
91- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92- func::ReturnOp returnOp;
93- for (Block &b : funcOp.getBody ()) {
94- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
91+ static Operation* getAssumedUniqueReturnOp (FunctionOpInterface funcOp) {
92+ Operation *returnOp = nullptr ;
93+ for (Block &b : funcOp.getFunctionBody ()) {
94+ auto candidateOp = b.getTerminator ();
95+ if (candidateOp && candidateOp->hasTrait <OpTrait::ReturnLike>()) {
9596 if (returnOp)
9697 return nullptr ;
9798 returnOp = candidateOp;
@@ -126,16 +127,15 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
126127// / Store function BlockArguments that are equivalent to/aliasing a returned
127128// / value in FuncAnalysisState.
128129static LogicalResult
129- aliasingFuncOpBBArgsAnalysis (FuncOp funcOp, OneShotAnalysisState &state,
130+ aliasingFuncOpBBArgsAnalysis (FunctionOpInterface funcOp, OneShotAnalysisState &state,
130131 FuncAnalysisState &funcState) {
131- if (funcOp.getBody ().empty ()) {
132+ if (funcOp.getFunctionBody ().empty ()) {
132133 // No function body available. Conservatively assume that every tensor
133134 // return value may alias with any tensor bbArg.
134- FunctionType type = funcOp.getFunctionType ();
135- for (const auto &inputIt : llvm::enumerate (type.getInputs ())) {
135+ for (const auto &inputIt : llvm::enumerate (funcOp.getArgumentTypes ())) {
136136 if (!isa<TensorType>(inputIt.value ()))
137137 continue ;
138- for (const auto &resultIt : llvm::enumerate (type. getResults ())) {
138+ for (const auto &resultIt : llvm::enumerate (funcOp. getResultTypes ())) {
139139 if (!isa<TensorType>(resultIt.value ()))
140140 continue ;
141141 int64_t returnIdx = resultIt.index ();
@@ -147,7 +147,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
147147 }
148148
149149 // Support only single return-terminated block in the function.
150- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
150+ if (!isa<func::FuncOp>(funcOp))
151+ return success ();
152+ Operation *returnOp = getAssumedUniqueReturnOp (funcOp);
151153 assert (returnOp && " expected func with single return op" );
152154
153155 for (OpOperand &returnVal : returnOp->getOpOperands ())
@@ -168,7 +170,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
168170 return success ();
169171}
170172
171- static void annotateFuncArgAccess (func::FuncOp funcOp, int64_t idx, bool isRead,
173+ static void annotateFuncArgAccess (FunctionOpInterface funcOp, int64_t idx, bool isRead,
172174 bool isWritten) {
173175 OpBuilder b (funcOp.getContext ());
174176 Attribute accessType;
@@ -189,12 +191,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
189191// / function with unknown ops, we conservatively assume that such ops bufferize
190192// / to a read + write.
191193static LogicalResult
192- funcOpBbArgReadWriteAnalysis (FuncOp funcOp, OneShotAnalysisState &state,
194+ funcOpBbArgReadWriteAnalysis (FunctionOpInterface funcOp, OneShotAnalysisState &state,
193195 FuncAnalysisState &funcState) {
194- for (int64_t idx = 0 , e = funcOp.getFunctionType (). getNumInputs (); idx < e;
196+ for (int64_t idx = 0 , e = funcOp.getNumArguments (); idx < e;
195197 ++idx) {
196198 // Skip non-tensor arguments.
197- if (!isa<TensorType>(funcOp.getFunctionType (). getInput ( idx) ))
199+ if (!isa<TensorType>(funcOp.getArgumentTypes ()[ idx] ))
198200 continue ;
199201 bool isRead;
200202 bool isWritten;
@@ -204,7 +206,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
204206 StringRef str = accessAttr.getValue ();
205207 isRead = str == " read" || str == " read-write" ;
206208 isWritten = str == " write" || str == " read-write" ;
207- } else if (funcOp.getBody ().empty ()) {
209+ } else if (funcOp.getFunctionBody ().empty ()) {
208210 // If the function has no body, conservatively assume that all args are
209211 // read + written.
210212 isRead = true ;
@@ -230,23 +232,13 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
230232
231233// / Remove bufferization attributes on FuncOp arguments.
232234static void removeBufferizationAttributes (BlockArgument bbArg) {
233- auto funcOp = cast<func::FuncOp >(bbArg.getOwner ()->getParentOp ());
235+ auto funcOp = cast<FunctionOpInterface >(bbArg.getOwner ()->getParentOp ());
234236 funcOp.removeArgAttr (bbArg.getArgNumber (),
235237 BufferizationDialect::kBufferLayoutAttrName );
236238 funcOp.removeArgAttr (bbArg.getArgNumber (),
237239 BufferizationDialect::kWritableAttrName );
238240}
239241
240- // / Return the func::FuncOp called by `callOp`.
241- static func::FuncOp getCalledFunction (func::CallOp callOp) {
242- SymbolRefAttr sym =
243- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
244- if (!sym)
245- return nullptr ;
246- return dyn_cast_or_null<func::FuncOp>(
247- SymbolTable::lookupNearestSymbolFrom (callOp, sym));
248- }
249-
250242static FunctionOpInterface getCalledFunction (CallOpInterface callOp) {
251243 SymbolRefAttr sym =
252244 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
@@ -260,12 +252,12 @@ static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
260252// / Note: This only adds new equivalence info if the called function was already
261253// / analyzed.
262254// TODO: This does not handle cyclic function call graphs etc.
263- static void equivalenceAnalysis (func::FuncOp funcOp,
255+ static void equivalenceAnalysis (FunctionOpInterface funcOp,
264256 OneShotAnalysisState &state,
265257 FuncAnalysisState &funcState) {
266- funcOp->walk ([&](func::CallOp callOp) {
267- func::FuncOp calledFunction = getCalledFunction (callOp);
268- assert (calledFunction && " could not retrieved called func::FuncOp " );
258+ funcOp->walk ([&](CallOpInterface callOp) {
259+ FunctionOpInterface calledFunction = getCalledFunction (callOp);
260+ assert (calledFunction && " could not retrieved called FunctionOpInterface " );
269261
270262 // No equivalence info available for the called function.
271263 if (!funcState.equivalentFuncArgs .count (calledFunction))
@@ -276,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
276268 int64_t bbargIdx = it.second ;
277269 if (!state.isInPlace (callOp->getOpOperand (bbargIdx)))
278270 continue ;
279- Value returnVal = callOp. getResult (returnIdx);
271+ Value returnVal = callOp-> getResult (returnIdx);
280272 Value argVal = callOp->getOperand (bbargIdx);
281273 state.unionEquivalenceClasses (returnVal, argVal);
282274 }
@@ -308,23 +300,19 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
308300 // For each FuncOp, the number of func::CallOp it contains.
309301 DenseMap<FunctionOpInterface, unsigned > numberCallOpsContainedInFuncOp;
310302 WalkResult res = moduleOp.walk ([&](FunctionOpInterface funcOp) -> WalkResult {
311- // Only handle ReturnOp if funcOp is exactly the FuncOp type.
312- if (isa<FuncOp>(funcOp)) {
313- FuncOp funcOpCasted = cast<FuncOp>(funcOp);
314- if (!funcOpCasted.getBody ().empty ()) {
315- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOpCasted);
316- if (!returnOp)
317- return funcOp->emitError ()
318- << " cannot bufferize a FuncOp with tensors and "
319- " without a unique ReturnOp" ;
320- }
303+ if (!funcOp.getFunctionBody ().empty () && isa<func::FuncOp>(funcOp)) {
304+ Operation *returnOp = getAssumedUniqueReturnOp (funcOp);
305+ if (!returnOp)
306+ return funcOp->emitError ()
307+ << " cannot bufferize a FuncOp with tensors and "
308+ " without a unique ReturnOp" ;
321309 }
322310
323311 // Collect function calls and populate the caller map.
324312 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
325313 return funcOp.walk ([&](CallOpInterface callOp) -> WalkResult {
326314 FunctionOpInterface calledFunction = getCalledFunction (callOp);
327- assert (calledFunction && " could not retrieved called func::FuncOp " );
315+ assert (calledFunction && " could not retrieved called FunctionOpInterface " );
328316 // If the called function does not have any tensors in its signature, then
329317 // it is not necessary to bufferize the callee before the caller.
330318 if (!hasTensorSignature (calledFunction))
@@ -362,11 +350,15 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
362350// / most generic layout map as function return types. After bufferizing the
363351// / entire function body, a more concise memref type can potentially be used for
364352// / the return type of the function.
365- static void foldMemRefCasts (func::FuncOp funcOp) {
366- if (funcOp.getBody ().empty ())
353+ static void foldMemRefCasts (FunctionOpInterface funcOp) {
354+ if (funcOp.getFunctionBody ().empty ())
355+ return ;
356+
357+ Operation *returnOp = getAssumedUniqueReturnOp (funcOp);
358+
359+ if (!returnOp)
367360 return ;
368361
369- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
370362 SmallVector<Type> resultTypes;
371363
372364 for (OpOperand &operand : returnOp->getOpOperands ()) {
@@ -379,7 +371,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
379371 }
380372
381373 auto newFuncType = FunctionType::get (
382- funcOp.getContext (), funcOp.getFunctionType (). getInputs (), resultTypes);
374+ funcOp.getContext (), funcOp.getArgumentTypes (), resultTypes);
383375 funcOp.setType (newFuncType);
384376}
385377
@@ -403,39 +395,34 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
403395 // Analyze ops.
404396 for (FunctionOpInterface funcOp : orderedFuncOps) {
405397
406- // The following analysis is specific to the FuncOp type.
407- if (!isa<FuncOp>(funcOp))
408- continue ;
409- FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
410-
411- if (!state.getOptions ().isOpAllowed (funcOpCasted))
398+ if (!state.getOptions ().isOpAllowed (funcOp))
412399 continue ;
413400
414401 // Now analyzing function.
415- funcState.startFunctionAnalysis (funcOpCasted );
402+ funcState.startFunctionAnalysis (funcOp );
416403
417404 // Gather equivalence info for CallOps.
418- equivalenceAnalysis (funcOpCasted , state, funcState);
405+ equivalenceAnalysis (funcOp , state, funcState);
419406
420407 // Analyze funcOp.
421- if (failed (analyzeOp (funcOpCasted , state, statistics)))
408+ if (failed (analyzeOp (funcOp , state, statistics)))
422409 return failure ();
423410
424411 // Run some extra function analyses.
425- if (failed (aliasingFuncOpBBArgsAnalysis (funcOpCasted , state, funcState)) ||
426- failed (funcOpBbArgReadWriteAnalysis (funcOpCasted , state, funcState)))
412+ if (failed (aliasingFuncOpBBArgsAnalysis (funcOp , state, funcState)) ||
413+ failed (funcOpBbArgReadWriteAnalysis (funcOp , state, funcState)))
427414 return failure ();
428415
429416 // Mark op as fully analyzed.
430- funcState.analyzedFuncOps [funcOpCasted ] = FuncOpAnalysisState::Analyzed;
417+ funcState.analyzedFuncOps [funcOp ] = FuncOpAnalysisState::Analyzed;
431418 }
432419
433420 return success ();
434421}
435422
436423void mlir::bufferization::removeBufferizationAttributesInModule (
437424 ModuleOp moduleOp) {
438- moduleOp.walk ([&](func::FuncOp op) {
425+ moduleOp.walk ([&](FunctionOpInterface op) {
439426 for (BlockArgument bbArg : op.getArguments ())
440427 removeBufferizationAttributes (bbArg);
441428 });
@@ -475,14 +462,14 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
475462 }
476463
477464 // Change buffer return types to more precise layout maps.
478- if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp) )
479- foldMemRefCasts (cast<func::FuncOp>( funcOp) );
465+ if (options.inferFunctionResultLayout )
466+ foldMemRefCasts (funcOp);
480467 }
481468
482469 // Bufferize all other ops.
483470 for (Operation &op : llvm::make_early_inc_range (moduleOp.getOps ())) {
484471 // Functions were already bufferized.
485- if (isa<func::FuncOp >(&op))
472+ if (isa<FunctionOpInterface >(&op))
486473 continue ;
487474 if (failed (bufferizeOp (&op, options, statistics)))
488475 return failure ();
@@ -509,12 +496,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
509496 // FuncOps whose names are specified in options.noAnalysisFuncFilter will
510497 // not be analyzed. Ops in these FuncOps will not be analyzed as well.
511498 OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
512- auto func = dyn_cast<func::FuncOp >(op);
499+ auto func = dyn_cast<FunctionOpInterface >(op);
513500 if (!func)
514- func = op->getParentOfType <func::FuncOp >();
501+ func = op->getParentOfType <FunctionOpInterface >();
515502 if (func)
516503 return llvm::is_contained (options.noAnalysisFuncFilter ,
517- func.getSymName ());
504+ func.getName ());
518505 return false ;
519506 };
520507 OneShotBufferizationOptions updatedOptions (options);
0 commit comments