@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
7575using namespace mlir ::bufferization::func_ext;
7676
7777// / A mapping of FuncOps to their callers.
78- using FuncCallerMap = DenseMap<FunctionOpInterface , DenseSet<Operation *>>;
78+ using FuncCallerMap = DenseMap<func::FuncOp , DenseSet<Operation *>>;
7979
8080// / Get or create FuncAnalysisState.
8181static FuncAnalysisState &
@@ -88,11 +88,10 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8888
8989// / Return the unique ReturnOp that terminates `funcOp`.
9090// / Return nullptr if there is no such unique ReturnOp.
91- static Operation *getAssumedUniqueReturnOp (FunctionOpInterface funcOp) {
92- Operation *returnOp = nullptr ;
93- for (Block &b : funcOp.getFunctionBody ()) {
94- auto candidateOp = b.getTerminator ();
95- if (candidateOp && candidateOp->hasTrait <OpTrait::ReturnLike>()) {
91+ static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
92+ func::ReturnOp returnOp;
93+ for (Block &b : funcOp.getBody ()) {
94+ if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
9695 if (returnOp)
9796 return nullptr ;
9897 returnOp = candidateOp;
@@ -127,16 +126,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
127126// / Store function BlockArguments that are equivalent to/aliasing a returned
128127// / value in FuncAnalysisState.
129128static LogicalResult
130- aliasingFuncOpBBArgsAnalysis (FunctionOpInterface funcOp,
131- OneShotAnalysisState &state,
129+ aliasingFuncOpBBArgsAnalysis (FuncOp funcOp, OneShotAnalysisState &state,
132130 FuncAnalysisState &funcState) {
133- if (funcOp.getFunctionBody ().empty ()) {
131+ if (funcOp.getBody ().empty ()) {
134132 // No function body available. Conservatively assume that every tensor
135133 // return value may alias with any tensor bbArg.
136- for (const auto &inputIt : llvm::enumerate (funcOp.getArgumentTypes ())) {
134+ FunctionType type = funcOp.getFunctionType ();
135+ for (const auto &inputIt : llvm::enumerate (type.getInputs ())) {
137136 if (!isa<TensorType>(inputIt.value ()))
138137 continue ;
139- for (const auto &resultIt : llvm::enumerate (funcOp. getResultTypes ())) {
138+ for (const auto &resultIt : llvm::enumerate (type. getResults ())) {
140139 if (!isa<TensorType>(resultIt.value ()))
141140 continue ;
142141 int64_t returnIdx = resultIt.index ();
@@ -148,7 +147,7 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
148147 }
149148
150149 // Support only single return-terminated block in the function.
151- Operation * returnOp = getAssumedUniqueReturnOp (funcOp);
150+ func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
152151 assert (returnOp && " expected func with single return op" );
153152
154153 for (OpOperand &returnVal : returnOp->getOpOperands ())
@@ -169,8 +168,8 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
169168 return success ();
170169}
171170
172- static void annotateFuncArgAccess (FunctionOpInterface funcOp, int64_t idx,
173- bool isRead, bool isWritten) {
171+ static void annotateFuncArgAccess (func::FuncOp funcOp, int64_t idx, bool isRead ,
172+ bool isWritten) {
174173 OpBuilder b (funcOp.getContext ());
175174 Attribute accessType;
176175 if (isRead && isWritten) {
@@ -190,12 +189,12 @@ static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
190189// / function with unknown ops, we conservatively assume that such ops bufferize
191190// / to a read + write.
192191static LogicalResult
193- funcOpBbArgReadWriteAnalysis (FunctionOpInterface funcOp,
194- OneShotAnalysisState &state,
192+ funcOpBbArgReadWriteAnalysis (FuncOp funcOp, OneShotAnalysisState &state,
195193 FuncAnalysisState &funcState) {
196- for (int64_t idx = 0 , e = funcOp.getNumArguments (); idx < e; ++idx) {
194+ for (int64_t idx = 0 , e = funcOp.getFunctionType ().getNumInputs (); idx < e;
195+ ++idx) {
197196 // Skip non-tensor arguments.
198- if (!isa<TensorType>(funcOp.getArgumentTypes ()[ idx] ))
197+ if (!isa<TensorType>(funcOp.getFunctionType (). getInput ( idx) ))
199198 continue ;
200199 bool isRead;
201200 bool isWritten;
@@ -205,7 +204,7 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
205204 StringRef str = accessAttr.getValue ();
206205 isRead = str == " read" || str == " read-write" ;
207206 isWritten = str == " write" || str == " read-write" ;
208- } else if (funcOp.getFunctionBody ().empty ()) {
207+ } else if (funcOp.getBody ().empty ()) {
209208 // If the function has no body, conservatively assume that all args are
210209 // read + written.
211210 isRead = true ;
@@ -231,33 +230,33 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
231230
232231// / Remove bufferization attributes on FuncOp arguments.
233232static void removeBufferizationAttributes (BlockArgument bbArg) {
234- auto funcOp = cast<FunctionOpInterface >(bbArg.getOwner ()->getParentOp ());
233+ auto funcOp = cast<func::FuncOp >(bbArg.getOwner ()->getParentOp ());
235234 funcOp.removeArgAttr (bbArg.getArgNumber (),
236235 BufferizationDialect::kBufferLayoutAttrName );
237236 funcOp.removeArgAttr (bbArg.getArgNumber (),
238237 BufferizationDialect::kWritableAttrName );
239238}
240239
241- static FunctionOpInterface getCalledFunction (CallOpInterface callOp) {
240+ // / Return the func::FuncOp called by `callOp`.
241+ static func::FuncOp getCalledFunction (func::CallOp callOp) {
242242 SymbolRefAttr sym =
243243 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
244244 if (!sym)
245245 return nullptr ;
246- return dyn_cast_or_null<FunctionOpInterface >(
246+ return dyn_cast_or_null<func::FuncOp >(
247247 SymbolTable::lookupNearestSymbolFrom (callOp, sym));
248248}
249249
250250// / Gather equivalence info of CallOps.
251251// / Note: This only adds new equivalence info if the called function was already
252252// / analyzed.
253253// TODO: This does not handle cyclic function call graphs etc.
254- static void equivalenceAnalysis (FunctionOpInterface funcOp,
254+ static void equivalenceAnalysis (func::FuncOp funcOp,
255255 OneShotAnalysisState &state,
256256 FuncAnalysisState &funcState) {
257- funcOp->walk ([&](CallOpInterface callOp) {
258- FunctionOpInterface calledFunction = getCalledFunction (callOp);
259- if (!calledFunction)
260- return WalkResult::skip ();
257+ funcOp->walk ([&](func::CallOp callOp) {
258+ func::FuncOp calledFunction = getCalledFunction (callOp);
259+ assert (calledFunction && " could not retrieved called func::FuncOp" );
261260
262261 // No equivalence info available for the called function.
263262 if (!funcState.equivalentFuncArgs .count (calledFunction))
@@ -268,7 +267,7 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
268267 int64_t bbargIdx = it.second ;
269268 if (!state.isInPlace (callOp->getOpOperand (bbargIdx)))
270269 continue ;
271- Value returnVal = callOp-> getResult (returnIdx);
270+ Value returnVal = callOp. getResult (returnIdx);
272271 Value argVal = callOp->getOperand (bbargIdx);
273272 state.unionEquivalenceClasses (returnVal, argVal);
274273 }
@@ -278,9 +277,11 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
278277}
279278
280279// / Return "true" if the given function signature has tensor semantics.
281- static bool hasTensorSignature (FunctionOpInterface funcOp) {
282- return llvm::any_of (funcOp.getArgumentTypes (), llvm::IsaPred<TensorType>) ||
283- llvm::any_of (funcOp.getResultTypes (), llvm::IsaPred<TensorType>);
280+ static bool hasTensorSignature (func::FuncOp funcOp) {
281+ return llvm::any_of (funcOp.getFunctionType ().getInputs (),
282+ llvm::IsaPred<TensorType>) ||
283+ llvm::any_of (funcOp.getFunctionType ().getResults (),
284+ llvm::IsaPred<TensorType>);
284285}
285286
286287// / Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -290,16 +291,16 @@ static bool hasTensorSignature(FunctionOpInterface funcOp) {
290291// / retrieve the called FuncOp from any func::CallOp.
291292static LogicalResult
292293getFuncOpsOrderedByCalls (ModuleOp moduleOp,
293- SmallVectorImpl<FunctionOpInterface > &orderedFuncOps,
294+ SmallVectorImpl<func::FuncOp > &orderedFuncOps,
294295 FuncCallerMap &callerMap) {
295296 // For each FuncOp, the set of functions called by it (i.e. the union of
296297 // symbols of all nested func::CallOp).
297- DenseMap<FunctionOpInterface , DenseSet<FunctionOpInterface >> calledBy;
298+ DenseMap<func::FuncOp , DenseSet<func::FuncOp >> calledBy;
298299 // For each FuncOp, the number of func::CallOp it contains.
299- DenseMap<FunctionOpInterface , unsigned > numberCallOpsContainedInFuncOp;
300- WalkResult res = moduleOp.walk ([&](FunctionOpInterface funcOp) -> WalkResult {
301- if (!funcOp.getFunctionBody ().empty ()) {
302- Operation * returnOp = getAssumedUniqueReturnOp (funcOp);
300+ DenseMap<func::FuncOp , unsigned > numberCallOpsContainedInFuncOp;
301+ WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
302+ if (!funcOp.getBody ().empty ()) {
303+ func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
303304 if (!returnOp)
304305 return funcOp->emitError ()
305306 << " cannot bufferize a FuncOp with tensors and "
@@ -308,10 +309,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
308309
309310 // Collect function calls and populate the caller map.
310311 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
311- return funcOp.walk ([&](CallOpInterface callOp) -> WalkResult {
312- FunctionOpInterface calledFunction = getCalledFunction (callOp);
313- if (!calledFunction)
314- return WalkResult::skip ();
312+ return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
313+ func::FuncOp calledFunction = getCalledFunction (callOp);
314+ assert (calledFunction && " could not retrieved called func::FuncOp" );
315315 // If the called function does not have any tensors in its signature, then
316316 // it is not necessary to bufferize the callee before the caller.
317317 if (!hasTensorSignature (calledFunction))
@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
349349// / most generic layout map as function return types. After bufferizing the
350350// / entire function body, a more concise memref type can potentially be used for
351351// / the return type of the function.
352- static void foldMemRefCasts (FunctionOpInterface funcOp) {
353- if (funcOp.getFunctionBody ().empty ())
352+ static void foldMemRefCasts (func::FuncOp funcOp) {
353+ if (funcOp.getBody ().empty ())
354354 return ;
355355
356- Operation * returnOp = getAssumedUniqueReturnOp (funcOp);
356+ func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
357357 SmallVector<Type> resultTypes;
358358
359359 for (OpOperand &operand : returnOp->getOpOperands ()) {
@@ -365,8 +365,8 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) {
365365 }
366366 }
367367
368- auto newFuncType = FunctionType::get (funcOp. getContext (),
369- funcOp.getArgumentTypes (), resultTypes);
368+ auto newFuncType = FunctionType::get (
369+ funcOp. getContext (), funcOp.getFunctionType (). getInputs (), resultTypes);
370370 funcOp.setType (newFuncType);
371371}
372372
@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379379 FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380380
381381 // A list of functions in the order in which they are analyzed + bufferized.
382- SmallVector<FunctionOpInterface > orderedFuncOps;
382+ SmallVector<func::FuncOp > orderedFuncOps;
383383
384384 // A mapping of FuncOps to their callers.
385385 FuncCallerMap callerMap;
@@ -388,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
388388 return failure ();
389389
390390 // Analyze ops.
391- for (FunctionOpInterface funcOp : orderedFuncOps) {
391+ for (func::FuncOp funcOp : orderedFuncOps) {
392392 if (!state.getOptions ().isOpAllowed (funcOp))
393393 continue ;
394394
@@ -416,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
416416
417417void mlir::bufferization::removeBufferizationAttributesInModule (
418418 ModuleOp moduleOp) {
419- moduleOp.walk ([&](FunctionOpInterface op) {
419+ moduleOp.walk ([&](func::FuncOp op) {
420420 for (BlockArgument bbArg : op.getArguments ())
421421 removeBufferizationAttributes (bbArg);
422422 });
@@ -430,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430430 IRRewriter rewriter (moduleOp.getContext ());
431431
432432 // A list of functions in the order in which they are analyzed + bufferized.
433- SmallVector<FunctionOpInterface > orderedFuncOps;
433+ SmallVector<func::FuncOp > orderedFuncOps;
434434
435435 // A mapping of FuncOps to their callers.
436436 FuncCallerMap callerMap;
@@ -439,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
439439 return failure ();
440440
441441 // Bufferize functions.
442- for (FunctionOpInterface funcOp : orderedFuncOps) {
442+ for (func::FuncOp funcOp : orderedFuncOps) {
443443 // Note: It would be good to apply cleanups here but we cannot as aliasInfo
444444 // would be invalidated.
445445
446- if (llvm::is_contained (options.noAnalysisFuncFilter , funcOp.getName ())) {
446+ if (llvm::is_contained (options.noAnalysisFuncFilter , funcOp.getSymName ())) {
447447 // This function was not analyzed and RaW conflicts were not resolved.
448448 // Buffer copies must be inserted before every write.
449449 OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
463463 // Bufferize all other ops.
464464 for (Operation &op : llvm::make_early_inc_range (moduleOp.getOps ())) {
465465 // Functions were already bufferized.
466- if (isa<FunctionOpInterface >(&op))
466+ if (isa<func::FuncOp >(&op))
467467 continue ;
468468 if (failed (bufferizeOp (&op, options, statistics)))
469469 return failure ();
@@ -490,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
490490 // FuncOps whose names are specified in options.noAnalysisFuncFilter will
491491 // not be analyzed. Ops in these FuncOps will not be analyzed as well.
492492 OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
493- auto func = dyn_cast<FunctionOpInterface >(op);
493+ auto func = dyn_cast<func::FuncOp >(op);
494494 if (!func)
495- func = op->getParentOfType <FunctionOpInterface >();
495+ func = op->getParentOfType <func::FuncOp >();
496496 if (func)
497497 return llvm::is_contained (options.noAnalysisFuncFilter ,
498- func.getName ());
498+ func.getSymName ());
499499 return false ;
500500 };
501501 OneShotBufferizationOptions updatedOptions (options);
0 commit comments