@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
7575using namespace mlir ::bufferization::func_ext;
7676
7777// / A mapping of FuncOps to their callers.
78- using FuncCallerMap = DenseMap<func::FuncOp , DenseSet<Operation *>>;
78+ using FuncCallerMap = DenseMap<FunctionOpInterface , DenseSet<Operation *>>;
7979
8080// / Get or create FuncAnalysisState.
8181static FuncAnalysisState &
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8888
8989// / Return the unique ReturnOp that terminates `funcOp`.
9090// / Return nullptr if there is no such unique ReturnOp.
91- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92- func::ReturnOp returnOp;
93- for (Block &b : funcOp.getBody ()) {
94- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
91+ static Operation *getAssumedUniqueReturnOp (FunctionOpInterface funcOp) {
92+ Operation *returnOp = nullptr ;
93+ for (Block &b : funcOp.getFunctionBody ()) {
94+ auto candidateOp = b.getTerminator ();
95+ if (candidateOp && candidateOp->hasTrait <OpTrait::ReturnLike>()) {
9596 if (returnOp)
9697 return nullptr ;
9798 returnOp = candidateOp;
@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
126127// / Store function BlockArguments that are equivalent to/aliasing a returned
127128// / value in FuncAnalysisState.
128129static LogicalResult
129- aliasingFuncOpBBArgsAnalysis (FuncOp funcOp, OneShotAnalysisState &state,
130+ aliasingFuncOpBBArgsAnalysis (FunctionOpInterface funcOp,
131+ OneShotAnalysisState &state,
130132 FuncAnalysisState &funcState) {
131- if (funcOp.getBody ().empty ()) {
133+ if (funcOp.getFunctionBody ().empty ()) {
132134 // No function body available. Conservatively assume that every tensor
133135 // return value may alias with any tensor bbArg.
134- FunctionType type = funcOp.getFunctionType ();
135- for (const auto &inputIt : llvm::enumerate (type.getInputs ())) {
136+ for (const auto &inputIt : llvm::enumerate (funcOp.getArgumentTypes ())) {
136137 if (!isa<TensorType>(inputIt.value ()))
137138 continue ;
138- for (const auto &resultIt : llvm::enumerate (type. getResults ())) {
139+ for (const auto &resultIt : llvm::enumerate (funcOp. getResultTypes ())) {
139140 if (!isa<TensorType>(resultIt.value ()))
140141 continue ;
141142 int64_t returnIdx = resultIt.index ();
@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
147148 }
148149
149150 // Support only single return-terminated block in the function.
150- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
151+ Operation * returnOp = getAssumedUniqueReturnOp (funcOp);
151152 assert (returnOp && " expected func with single return op" );
152153
153154 for (OpOperand &returnVal : returnOp->getOpOperands ())
@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
168169 return success ();
169170}
170171
171- static void annotateFuncArgAccess (func::FuncOp funcOp, int64_t idx, bool isRead ,
172- bool isWritten) {
172+ static void annotateFuncArgAccess (FunctionOpInterface funcOp, int64_t idx,
173+ bool isRead, bool isWritten) {
173174 OpBuilder b (funcOp.getContext ());
174175 Attribute accessType;
175176 if (isRead && isWritten) {
@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
189190// / function with unknown ops, we conservatively assume that such ops bufferize
190191// / to a read + write.
191192static LogicalResult
192- funcOpBbArgReadWriteAnalysis (FuncOp funcOp, OneShotAnalysisState &state,
193+ funcOpBbArgReadWriteAnalysis (FunctionOpInterface funcOp,
194+ OneShotAnalysisState &state,
193195 FuncAnalysisState &funcState) {
194- for (int64_t idx = 0 , e = funcOp.getFunctionType ().getNumInputs (); idx < e;
195- ++idx) {
196+ for (int64_t idx = 0 , e = funcOp.getNumArguments (); idx < e; ++idx) {
196197 // Skip non-tensor arguments.
197- if (!isa<TensorType>(funcOp.getFunctionType (). getInput ( idx) ))
198+ if (!isa<TensorType>(funcOp.getArgumentTypes ()[ idx] ))
198199 continue ;
199200 bool isRead;
200201 bool isWritten;
@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
204205 StringRef str = accessAttr.getValue ();
205206 isRead = str == " read" || str == " read-write" ;
206207 isWritten = str == " write" || str == " read-write" ;
207- } else if (funcOp.getBody ().empty ()) {
208+ } else if (funcOp.getFunctionBody ().empty ()) {
208209 // If the function has no body, conservatively assume that all args are
209210 // read + written.
210211 isRead = true ;
@@ -230,33 +231,32 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
230231
231232// / Remove bufferization attributes on FuncOp arguments.
232233static void removeBufferizationAttributes (BlockArgument bbArg) {
233- auto funcOp = cast<func::FuncOp >(bbArg.getOwner ()->getParentOp ());
234+ auto funcOp = cast<FunctionOpInterface >(bbArg.getOwner ()->getParentOp ());
234235 funcOp.removeArgAttr (bbArg.getArgNumber (),
235236 BufferizationDialect::kBufferLayoutAttrName );
236237 funcOp.removeArgAttr (bbArg.getArgNumber (),
237238 BufferizationDialect::kWritableAttrName );
238239}
239240
240- // / Return the func::FuncOp called by `callOp`.
241- static func::FuncOp getCalledFunction (func::CallOp callOp) {
241+ static FunctionOpInterface getCalledFunction (CallOpInterface callOp) {
242242 SymbolRefAttr sym =
243243 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
244244 if (!sym)
245245 return nullptr ;
246- return dyn_cast_or_null<func::FuncOp >(
246+ return dyn_cast_or_null<FunctionOpInterface >(
247247 SymbolTable::lookupNearestSymbolFrom (callOp, sym));
248248}
249249
250250// / Gather equivalence info of CallOps.
251251// / Note: This only adds new equivalence info if the called function was already
252252// / analyzed.
253253// TODO: This does not handle cyclic function call graphs etc.
254- static void equivalenceAnalysis (func::FuncOp funcOp,
254+ static void equivalenceAnalysis (FunctionOpInterface funcOp,
255255 OneShotAnalysisState &state,
256256 FuncAnalysisState &funcState) {
257- funcOp->walk ([&](func::CallOp callOp) {
258- func::FuncOp calledFunction = getCalledFunction (callOp);
259- assert (calledFunction && " could not retrieved called func::FuncOp " );
257+ funcOp->walk ([&](CallOpInterface callOp) {
258+ FunctionOpInterface calledFunction = getCalledFunction (callOp);
259+ assert (calledFunction && " could not retrieved called FunctionOpInterface " );
260260
261261 // No equivalence info available for the called function.
262262 if (!funcState.equivalentFuncArgs .count (calledFunction))
@@ -267,7 +267,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
267267 int64_t bbargIdx = it.second ;
268268 if (!state.isInPlace (callOp->getOpOperand (bbargIdx)))
269269 continue ;
270- Value returnVal = callOp. getResult (returnIdx);
270+ Value returnVal = callOp-> getResult (returnIdx);
271271 Value argVal = callOp->getOperand (bbargIdx);
272272 state.unionEquivalenceClasses (returnVal, argVal);
273273 }
@@ -277,11 +277,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
277277}
278278
279279// / Return "true" if the given function signature has tensor semantics.
280- static bool hasTensorSignature (func::FuncOp funcOp) {
281- return llvm::any_of (funcOp.getFunctionType ().getInputs (),
282- llvm::IsaPred<TensorType>) ||
283- llvm::any_of (funcOp.getFunctionType ().getResults (),
284- llvm::IsaPred<TensorType>);
280+ static bool hasTensorSignature (FunctionOpInterface funcOp) {
281+ return llvm::any_of (funcOp.getArgumentTypes (), llvm::IsaPred<TensorType>) ||
282+ llvm::any_of (funcOp.getResultTypes (), llvm::IsaPred<TensorType>);
285283}
286284
287285// / Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -291,16 +289,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
291289// / retrieve the called FuncOp from any func::CallOp.
292290static LogicalResult
293291getFuncOpsOrderedByCalls (ModuleOp moduleOp,
294- SmallVectorImpl<func::FuncOp > &orderedFuncOps,
292+ SmallVectorImpl<FunctionOpInterface > &orderedFuncOps,
295293 FuncCallerMap &callerMap) {
296294 // For each FuncOp, the set of functions called by it (i.e. the union of
297295 // symbols of all nested func::CallOp).
298- DenseMap<func::FuncOp , DenseSet<func::FuncOp >> calledBy;
296+ DenseMap<FunctionOpInterface , DenseSet<FunctionOpInterface >> calledBy;
299297 // For each FuncOp, the number of func::CallOp it contains.
300- DenseMap<func::FuncOp , unsigned > numberCallOpsContainedInFuncOp;
301- WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
302- if (!funcOp.getBody ().empty ()) {
303- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
298+ DenseMap<FunctionOpInterface , unsigned > numberCallOpsContainedInFuncOp;
299+ WalkResult res = moduleOp.walk ([&](FunctionOpInterface funcOp) -> WalkResult {
300+ if (!funcOp.getFunctionBody ().empty ()) {
301+ Operation * returnOp = getAssumedUniqueReturnOp (funcOp);
304302 if (!returnOp)
305303 return funcOp->emitError ()
306304 << " cannot bufferize a FuncOp with tensors and "
@@ -309,9 +307,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
309307
310308 // Collect function calls and populate the caller map.
311309 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
312- return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
313- func::FuncOp calledFunction = getCalledFunction (callOp);
314- assert (calledFunction && " could not retrieved called func::FuncOp" );
310+ return funcOp.walk ([&](CallOpInterface callOp) -> WalkResult {
311+ FunctionOpInterface calledFunction = getCalledFunction (callOp);
312+ assert (calledFunction &&
313+ " could not retrieved called FunctionOpInterface" );
315314 // If the called function does not have any tensors in its signature, then
316315 // it is not necessary to bufferize the callee before the caller.
317316 if (!hasTensorSignature (calledFunction))
@@ -349,11 +348,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
349348// / most generic layout map as function return types. After bufferizing the
350349// / entire function body, a more concise memref type can potentially be used for
351350// / the return type of the function.
352- static void foldMemRefCasts (func::FuncOp funcOp) {
353- if (funcOp.getBody ().empty ())
351+ static void foldMemRefCasts (FunctionOpInterface funcOp) {
352+ if (funcOp.getFunctionBody ().empty ())
354353 return ;
355354
356- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
355+ Operation * returnOp = getAssumedUniqueReturnOp (funcOp);
357356 SmallVector<Type> resultTypes;
358357
359358 for (OpOperand &operand : returnOp->getOpOperands ()) {
@@ -365,8 +364,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
365364 }
366365 }
367366
368- auto newFuncType = FunctionType::get (
369- funcOp. getContext (), funcOp.getFunctionType (). getInputs (), resultTypes);
367+ auto newFuncType = FunctionType::get (funcOp. getContext (),
368+ funcOp.getArgumentTypes (), resultTypes);
370369 funcOp.setType (newFuncType);
371370}
372371
@@ -379,7 +378,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379378 FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380379
381380 // A list of functions in the order in which they are analyzed + bufferized.
382- SmallVector<func::FuncOp > orderedFuncOps;
381+ SmallVector<FunctionOpInterface > orderedFuncOps;
383382
384383 // A mapping of FuncOps to their callers.
385384 FuncCallerMap callerMap;
@@ -388,7 +387,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
388387 return failure ();
389388
390389 // Analyze ops.
391- for (func::FuncOp funcOp : orderedFuncOps) {
390+ for (FunctionOpInterface funcOp : orderedFuncOps) {
392391 if (!state.getOptions ().isOpAllowed (funcOp))
393392 continue ;
394393
@@ -416,7 +415,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
416415
417416void mlir::bufferization::removeBufferizationAttributesInModule (
418417 ModuleOp moduleOp) {
419- moduleOp.walk ([&](func::FuncOp op) {
418+ moduleOp.walk ([&](FunctionOpInterface op) {
420419 for (BlockArgument bbArg : op.getArguments ())
421420 removeBufferizationAttributes (bbArg);
422421 });
@@ -430,7 +429,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430429 IRRewriter rewriter (moduleOp.getContext ());
431430
432431 // A list of functions in the order in which they are analyzed + bufferized.
433- SmallVector<func::FuncOp > orderedFuncOps;
432+ SmallVector<FunctionOpInterface > orderedFuncOps;
434433
435434 // A mapping of FuncOps to their callers.
436435 FuncCallerMap callerMap;
@@ -439,11 +438,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
439438 return failure ();
440439
441440 // Bufferize functions.
442- for (func::FuncOp funcOp : orderedFuncOps) {
441+ for (FunctionOpInterface funcOp : orderedFuncOps) {
443442 // Note: It would be good to apply cleanups here but we cannot as aliasInfo
444443 // would be invalidated.
445444
446- if (llvm::is_contained (options.noAnalysisFuncFilter , funcOp.getSymName ())) {
445+ if (llvm::is_contained (options.noAnalysisFuncFilter , funcOp.getName ())) {
447446 // This function was not analyzed and RaW conflicts were not resolved.
448447 // Buffer copies must be inserted before every write.
449448 OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +462,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
463462 // Bufferize all other ops.
464463 for (Operation &op : llvm::make_early_inc_range (moduleOp.getOps ())) {
465464 // Functions were already bufferized.
466- if (isa<func::FuncOp >(&op))
465+ if (isa<FunctionOpInterface >(&op))
467466 continue ;
468467 if (failed (bufferizeOp (&op, options, statistics)))
469468 return failure ();
@@ -490,12 +489,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
490489 // FuncOps whose names are specified in options.noAnalysisFuncFilter will
491490 // not be analyzed. Ops in these FuncOps will not be analyzed as well.
492491 OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
493- auto func = dyn_cast<func::FuncOp >(op);
492+ auto func = dyn_cast<FunctionOpInterface >(op);
494493 if (!func)
495- func = op->getParentOfType <func::FuncOp >();
494+ func = op->getParentOfType <FunctionOpInterface >();
496495 if (func)
497496 return llvm::is_contained (options.noAnalysisFuncFilter ,
498- func.getSymName ());
497+ func.getName ());
499498 return false ;
500499 };
501500 OneShotBufferizationOptions updatedOptions (options);
0 commit comments