@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
75
75
using namespace mlir ::bufferization::func_ext;
76
76
77
77
// / A mapping of FuncOps to their callers.
78
- using FuncCallerMap = DenseMap<func::FuncOp , DenseSet<Operation *>>;
78
+ using FuncCallerMap = DenseMap<FunctionOpInterface , DenseSet<Operation *>>;
79
79
80
80
// / Get or create FuncAnalysisState.
81
81
static FuncAnalysisState &
@@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
247
247
SymbolTable::lookupNearestSymbolFrom (callOp, sym));
248
248
}
249
249
250
+ static FunctionOpInterface getCalledFunction (CallOpInterface callOp) {
251
+ SymbolRefAttr sym =
252
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
253
+ if (!sym)
254
+ return nullptr ;
255
+ return dyn_cast_or_null<FunctionOpInterface>(
256
+ SymbolTable::lookupNearestSymbolFrom (callOp, sym));
257
+ }
258
+
250
259
// / Gather equivalence info of CallOps.
251
260
// / Note: This only adds new equivalence info if the called function was already
252
261
// / analyzed.
@@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
277
286
}
278
287
279
288
// / Return "true" if the given function signature has tensor semantics.
280
- static bool hasTensorSignature (func::FuncOp funcOp) {
281
- return llvm::any_of (funcOp.getFunctionType (). getInputs (),
289
+ static bool hasTensorSignature (FunctionOpInterface funcOp) {
290
+ return llvm::any_of (funcOp.getArgumentTypes (),
282
291
llvm::IsaPred<TensorType>) ||
283
- llvm::any_of (funcOp.getFunctionType (). getResults (),
292
+ llvm::any_of (funcOp.getResultTypes (),
284
293
llvm::IsaPred<TensorType>);
285
294
}
286
295
@@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
291
300
// / retrieve the called FuncOp from any func::CallOp.
292
301
static LogicalResult
293
302
getFuncOpsOrderedByCalls (ModuleOp moduleOp,
294
- SmallVectorImpl<func::FuncOp > &orderedFuncOps,
303
+ SmallVectorImpl<FunctionOpInterface > &orderedFuncOps,
295
304
FuncCallerMap &callerMap) {
296
305
// For each FuncOp, the set of functions called by it (i.e. the union of
297
306
// symbols of all nested func::CallOp).
298
- DenseMap<func::FuncOp , DenseSet<func::FuncOp >> calledBy;
307
+ DenseMap<FunctionOpInterface , DenseSet<FunctionOpInterface >> calledBy;
299
308
// For each FuncOp, the number of func::CallOp it contains.
300
- DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
301
- WalkResult res = moduleOp.walk ([&](func::FuncOp funcOp) -> WalkResult {
302
- if (!funcOp.getBody ().empty ()) {
303
- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
304
- if (!returnOp)
305
- return funcOp->emitError ()
306
- << " cannot bufferize a FuncOp with tensors and "
307
- " without a unique ReturnOp" ;
309
+ DenseMap<FunctionOpInterface, unsigned > numberCallOpsContainedInFuncOp;
310
+ WalkResult res = moduleOp.walk ([&](FunctionOpInterface funcOp) -> WalkResult {
311
+ // Only handle ReturnOp if funcOp is exactly the FuncOp type.
312
+ if (isa<FuncOp>(funcOp)) {
313
+ FuncOp funcOpCasted = cast<FuncOp>(funcOp);
314
+ if (!funcOpCasted.getBody ().empty ()) {
315
+ func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOpCasted);
316
+ if (!returnOp)
317
+ return funcOp->emitError ()
318
+ << " cannot bufferize a FuncOp with tensors and "
319
+ " without a unique ReturnOp" ;
320
+ }
308
321
}
309
322
310
323
// Collect function calls and populate the caller map.
311
324
numberCallOpsContainedInFuncOp[funcOp] = 0 ;
312
- return funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
313
- func::FuncOp calledFunction = getCalledFunction (callOp);
325
+ return funcOp.walk ([&](CallOpInterface callOp) -> WalkResult {
326
+ FunctionOpInterface calledFunction = getCalledFunction (callOp);
314
327
assert (calledFunction && " could not retrieved called func::FuncOp" );
315
328
// If the called function does not have any tensors in its signature, then
316
329
// it is not necessary to bufferize the callee before the caller.
@@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379
392
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380
393
381
394
// A list of functions in the order in which they are analyzed + bufferized.
382
- SmallVector<func::FuncOp > orderedFuncOps;
395
+ SmallVector<FunctionOpInterface > orderedFuncOps;
383
396
384
397
// A mapping of FuncOps to their callers.
385
398
FuncCallerMap callerMap;
@@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
388
401
return failure ();
389
402
390
403
// Analyze ops.
391
- for (func::FuncOp funcOp : orderedFuncOps) {
392
- if (!state.getOptions ().isOpAllowed (funcOp))
404
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
405
+
406
+ // The following analysis is specific to the FuncOp type.
407
+ if (!isa<FuncOp>(funcOp))
408
+ continue ;
409
+ FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
410
+
411
+ if (!state.getOptions ().isOpAllowed (funcOpCasted))
393
412
continue ;
394
413
395
414
// Now analyzing function.
396
- funcState.startFunctionAnalysis (funcOp );
415
+ funcState.startFunctionAnalysis (funcOpCasted );
397
416
398
417
// Gather equivalence info for CallOps.
399
- equivalenceAnalysis (funcOp , state, funcState);
418
+ equivalenceAnalysis (funcOpCasted , state, funcState);
400
419
401
420
// Analyze funcOp.
402
- if (failed (analyzeOp (funcOp , state, statistics)))
421
+ if (failed (analyzeOp (funcOpCasted , state, statistics)))
403
422
return failure ();
404
423
405
424
// Run some extra function analyses.
406
- if (failed (aliasingFuncOpBBArgsAnalysis (funcOp , state, funcState)) ||
407
- failed (funcOpBbArgReadWriteAnalysis (funcOp , state, funcState)))
425
+ if (failed (aliasingFuncOpBBArgsAnalysis (funcOpCasted , state, funcState)) ||
426
+ failed (funcOpBbArgReadWriteAnalysis (funcOpCasted , state, funcState)))
408
427
return failure ();
409
428
410
429
// Mark op as fully analyzed.
411
- funcState.analyzedFuncOps [funcOp ] = FuncOpAnalysisState::Analyzed;
430
+ funcState.analyzedFuncOps [funcOpCasted ] = FuncOpAnalysisState::Analyzed;
412
431
}
413
432
414
433
return success ();
@@ -430,20 +449,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430
449
IRRewriter rewriter (moduleOp.getContext ());
431
450
432
451
// A list of functions in the order in which they are analyzed + bufferized.
433
- SmallVector<func::FuncOp > orderedFuncOps;
452
+ SmallVector<FunctionOpInterface > orderedFuncOps;
434
453
435
454
// A mapping of FuncOps to their callers.
436
455
FuncCallerMap callerMap;
437
456
438
457
if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
439
458
return failure ();
459
+ SmallVector<FunctionOpInterface> ops;
440
460
441
461
// Bufferize functions.
442
- for (func::FuncOp funcOp : orderedFuncOps) {
462
+ for (FunctionOpInterface funcOp : orderedFuncOps) {
443
463
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
444
464
// would be invalidated.
445
-
446
- if (llvm::is_contained (options.noAnalysisFuncFilter , funcOp.getSymName ())) {
465
+ if (llvm::is_contained (options.noAnalysisFuncFilter , funcOp.getName ())) {
447
466
// This function was not analyzed and RaW conflicts were not resolved.
448
467
// Buffer copies must be inserted before every write.
449
468
OneShotBufferizationOptions updatedOptions = options;
@@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
456
475
}
457
476
458
477
// Change buffer return types to more precise layout maps.
459
- if (options.inferFunctionResultLayout )
460
- foldMemRefCasts (funcOp);
478
+ if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp) )
479
+ foldMemRefCasts (cast<func::FuncOp>( funcOp) );
461
480
}
462
481
463
482
// Bufferize all other ops.
0 commit comments