Skip to content

Commit 8a5aca2

Browse files
committed
Make OneShotModuleBufferize accept FunctionOpInterface and CallOpInterface
1 parent 660e34f commit 8a5aca2

File tree

1 file changed

+50
-31
lines changed

1 file changed

+50
-31
lines changed

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
7575
using namespace mlir::bufferization::func_ext;
7676

7777
/// A mapping of FuncOps to their callers.
78-
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
78+
using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
7979

8080
/// Get or create FuncAnalysisState.
8181
static FuncAnalysisState &
@@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
247247
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
248248
}
249249

250+
static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
251+
SymbolRefAttr sym =
252+
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
253+
if (!sym)
254+
return nullptr;
255+
return dyn_cast_or_null<FunctionOpInterface>(
256+
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
257+
}
258+
250259
/// Gather equivalence info of CallOps.
251260
/// Note: This only adds new equivalence info if the called function was already
252261
/// analyzed.
@@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
277286
}
278287

279288
/// Return "true" if the given function signature has tensor semantics.
280-
static bool hasTensorSignature(func::FuncOp funcOp) {
281-
return llvm::any_of(funcOp.getFunctionType().getInputs(),
289+
static bool hasTensorSignature(FunctionOpInterface funcOp) {
290+
return llvm::any_of(funcOp.getArgumentTypes(),
282291
llvm::IsaPred<TensorType>) ||
283-
llvm::any_of(funcOp.getFunctionType().getResults(),
292+
llvm::any_of(funcOp.getResultTypes(),
284293
llvm::IsaPred<TensorType>);
285294
}
286295

@@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
291300
/// retrieve the called FuncOp from any func::CallOp.
292301
static LogicalResult
293302
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
294-
SmallVectorImpl<func::FuncOp> &orderedFuncOps,
303+
SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
295304
FuncCallerMap &callerMap) {
296305
// For each FuncOp, the set of functions called by it (i.e. the union of
297306
// symbols of all nested func::CallOp).
298-
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
307+
DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
299308
// For each FuncOp, the number of func::CallOp it contains.
300-
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
301-
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
302-
if (!funcOp.getBody().empty()) {
303-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
304-
if (!returnOp)
305-
return funcOp->emitError()
306-
<< "cannot bufferize a FuncOp with tensors and "
307-
"without a unique ReturnOp";
309+
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
310+
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
311+
// Only handle ReturnOp if funcOp is exactly the FuncOp type.
312+
if(isa<FuncOp>(funcOp)) {
313+
FuncOp funcOpCasted = cast<FuncOp>(funcOp);
314+
if (!funcOpCasted.getBody().empty()) {
315+
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
316+
if (!returnOp)
317+
return funcOp->emitError()
318+
<< "cannot bufferize a FuncOp with tensors and "
319+
"without a unique ReturnOp";
320+
}
308321
}
309322

310323
// Collect function calls and populate the caller map.
311324
numberCallOpsContainedInFuncOp[funcOp] = 0;
312-
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
313-
func::FuncOp calledFunction = getCalledFunction(callOp);
325+
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
326+
FunctionOpInterface calledFunction = getCalledFunction(callOp);
314327
assert(calledFunction && "could not retrieved called func::FuncOp");
315328
// If the called function does not have any tensors in its signature, then
316329
// it is not necessary to bufferize the callee before the caller.
@@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379392
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
380393

381394
// A list of functions in the order in which they are analyzed + bufferized.
382-
SmallVector<func::FuncOp> orderedFuncOps;
395+
SmallVector<FunctionOpInterface> orderedFuncOps;
383396

384397
// A mapping of FuncOps to their callers.
385398
FuncCallerMap callerMap;
@@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
388401
return failure();
389402

390403
// Analyze ops.
391-
for (func::FuncOp funcOp : orderedFuncOps) {
392-
if (!state.getOptions().isOpAllowed(funcOp))
404+
for (FunctionOpInterface funcOp : orderedFuncOps) {
405+
406+
// The following analysis is specific to the FuncOp type.
407+
if(!isa<FuncOp>(funcOp))
408+
continue;
409+
FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
410+
411+
if (!state.getOptions().isOpAllowed(funcOpCasted))
393412
continue;
394413

395414
// Now analyzing function.
396-
funcState.startFunctionAnalysis(funcOp);
415+
funcState.startFunctionAnalysis(funcOpCasted);
397416

398417
// Gather equivalence info for CallOps.
399-
equivalenceAnalysis(funcOp, state, funcState);
418+
equivalenceAnalysis(funcOpCasted, state, funcState);
400419

401420
// Analyze funcOp.
402-
if (failed(analyzeOp(funcOp, state, statistics)))
421+
if (failed(analyzeOp(funcOpCasted, state, statistics)))
403422
return failure();
404423

405424
// Run some extra function analyses.
406-
if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
407-
failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
425+
if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
426+
failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
408427
return failure();
409428

410429
// Mark op as fully analyzed.
411-
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
430+
funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
412431
}
413432

414433
return success();
@@ -430,20 +449,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430449
IRRewriter rewriter(moduleOp.getContext());
431450

432451
// A list of functions in the order in which they are analyzed + bufferized.
433-
SmallVector<func::FuncOp> orderedFuncOps;
452+
SmallVector<FunctionOpInterface> orderedFuncOps;
434453

435454
// A mapping of FuncOps to their callers.
436455
FuncCallerMap callerMap;
437456

438457
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
439458
return failure();
459+
SmallVector<FunctionOpInterface> ops;
440460

441461
// Bufferize functions.
442-
for (func::FuncOp funcOp : orderedFuncOps) {
462+
for (FunctionOpInterface funcOp : orderedFuncOps) {
443463
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
444464
// would be invalidated.
445-
446-
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
465+
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
447466
// This function was not analyzed and RaW conflicts were not resolved.
448467
// Buffer copies must be inserted before every write.
449468
OneShotBufferizationOptions updatedOptions = options;
@@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
456475
}
457476

458477
// Change buffer return types to more precise layout maps.
459-
if (options.inferFunctionResultLayout)
460-
foldMemRefCasts(funcOp);
478+
if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
479+
foldMemRefCasts(cast<func::FuncOp>(funcOp));
461480
}
462481

463482
// Bufferize all other ops.

0 commit comments

Comments
 (0)