
Commit 1f8d847

Make getAssumedUniqueReturnOp detect ReturnLike ops and FuncAnalysisState use FunctionOpInterface
1 parent 5153af3 commit 1f8d847
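
For context, here is a compact, self-contained sketch of the pattern this commit moves to: terminators are matched by the generic OpTrait::ReturnLike trait on any FunctionOpInterface, instead of matching func::ReturnOp on func::FuncOp. This sketch is not part of the patch; the helper name findUniqueReturnLike is illustrative, and the header paths are assumptions that may differ across MLIR versions.

// Illustrative sketch only (not part of the patch). Assumed header locations.
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h" // OpTrait::ReturnLike (assumed path)
#include "mlir/Interfaces/FunctionInterfaces.h"    // FunctionOpInterface (assumed path)

using namespace mlir;

/// Return the single return-like terminator of `funcOp`, or nullptr if there
/// is no terminator with the ReturnLike trait or if there is more than one.
static Operation *findUniqueReturnLike(FunctionOpInterface funcOp) {
  Operation *returnOp = nullptr;
  for (Block &block : funcOp.getFunctionBody()) {
    Operation *terminator = block.getTerminator();
    if (!terminator || !terminator->hasTrait<OpTrait::ReturnLike>())
      continue;
    if (returnOp)
      return nullptr; // More than one return-like terminator: give up.
    returnOp = terminator;
  }
  return returnOp;
}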

File tree

3 files changed: +59 additions, −72 deletions


mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 6 additions & 6 deletions
@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
 
   /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
   /// indices.
-  DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
+  DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
 
   /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
-  DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
+  DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
 
   /// A set of all read BlockArguments of FuncOps.
-  DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
+  DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
 
   /// A set of all written-to BlockArguments of FuncOps.
-  DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
+  DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
 
   /// Keep track of which FuncOps are fully analyzed or currently being
   /// analyzed.
-  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+  DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
 
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
-  void startFunctionAnalysis(FuncOp funcOp);
+  void startFunctionAnalysis(FunctionOpInterface funcOp);
 };
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
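
The state maps above are now keyed on FunctionOpInterface rather than the concrete func::FuncOp, so analysis results can be attached to any function-like op. A minimal usage sketch follows; it is not part of the patch, the helper name isFullyAnalyzed is illustrative, and the FunctionInterfaces.h path is an assumption that may vary by MLIR version.

// Illustrative sketch only; relies on the header shown above.
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Interfaces/FunctionInterfaces.h" // FunctionOpInterface (assumed path)

using namespace mlir;
using namespace mlir::bufferization::func_ext;

/// Return true if `funcOp` has already been fully analyzed. This works for any
/// function-like op because analyzedFuncOps is keyed on the interface handle.
static bool isFullyAnalyzed(const FuncAnalysisState &funcState,
                            FunctionOpInterface funcOp) {
  auto it = funcState.analyzedFuncOps.find(funcOp);
  return it != funcState.analyzedFuncOps.end() &&
         it->second == FuncOpAnalysisState::Analyzed;
}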

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ namespace mlir {
 namespace bufferization {
 namespace func_ext {
 
-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
+void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
   auto createdAliasingResults =
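
Because startFunctionAnalysis now accepts the interface, a driver can seed per-function analysis state for every function-like op in a module, not only func.func. A hedged sketch, not part of the patch; the walk and the seedAllFunctions name are illustrative, and include paths are assumptions.

// Illustrative sketch only; assumed include paths.
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;
using namespace mlir::bufferization::func_ext;

/// Mark every function-like op in `moduleOp` as "analysis in progress".
static void seedAllFunctions(ModuleOp moduleOp, FuncAnalysisState &funcState) {
  moduleOp.walk([&](FunctionOpInterface funcOp) {
    funcState.startFunctionAnalysis(funcOp);
  });
}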

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 52 additions & 65 deletions
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
 
 /// Return the unique ReturnOp that terminates `funcOp`.
 /// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
+static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
+  Operation *returnOp = nullptr;
+  for (Block &b : funcOp.getFunctionBody()) {
+    Operation *candidateOp = b.getTerminator();
+    if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
       if (returnOp)
         return nullptr;
       returnOp = candidateOp;
@@ -126,16 +127,15 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
 /// Store function BlockArguments that are equivalent to/aliasing a returned
 /// value in FuncAnalysisState.
 static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  if (funcOp.getBody().empty()) {
+  if (funcOp.getFunctionBody().empty()) {
     // No function body available. Conservatively assume that every tensor
     // return value may alias with any tensor bbArg.
-    FunctionType type = funcOp.getFunctionType();
-    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+    for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
-      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+      for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
@@ -147,7 +147,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   }
 
   // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  if (!isa<func::FuncOp>(funcOp))
+    return success();
+  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
   assert(returnOp && "expected func with single return op");
 
   for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,7 +170,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   return success();
 }
 
-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx, bool isRead,
                                   bool isWritten) {
   OpBuilder b(funcOp.getContext());
   Attribute accessType;
@@ -189,12 +191,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
 /// function with unknown ops, we conservatively assume that such ops bufferize
 /// to a read + write.
 static LogicalResult
-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
+  for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
-    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+    if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
      continue;
    bool isRead;
    bool isWritten;
@@ -204,7 +206,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
       StringRef str = accessAttr.getValue();
       isRead = str == "read" || str == "read-write";
       isWritten = str == "write" || str == "read-write";
-    } else if (funcOp.getBody().empty()) {
+    } else if (funcOp.getFunctionBody().empty()) {
       // If the function has no body, conservatively assume that all args are
       // read + written.
       isRead = true;
@@ -230,23 +232,13 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
 
 /// Remove bufferization attributes on FuncOp arguments.
 static void removeBufferizationAttributes(BlockArgument bbArg) {
-  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+  auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kBufferLayoutAttrName);
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kWritableAttrName);
 }
 
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
-  SymbolRefAttr sym =
-      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
-  if (!sym)
-    return nullptr;
-  return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
 static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
@@ -260,12 +252,12 @@ static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
 /// Note: This only adds new equivalence info if the called function was already
 /// analyzed.
 // TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(func::FuncOp funcOp,
+static void equivalenceAnalysis(FunctionOpInterface funcOp,
                                 OneShotAnalysisState &state,
                                 FuncAnalysisState &funcState) {
-  funcOp->walk([&](func::CallOp callOp) {
-    func::FuncOp calledFunction = getCalledFunction(callOp);
-    assert(calledFunction && "could not retrieved called func::FuncOp");
+  funcOp->walk([&](CallOpInterface callOp) {
+    FunctionOpInterface calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieve called FunctionOpInterface");
 
     // No equivalence info available for the called function.
     if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -276,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
       int64_t bbargIdx = it.second;
       if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
         continue;
-      Value returnVal = callOp.getResult(returnIdx);
+      Value returnVal = callOp->getResult(returnIdx);
       Value argVal = callOp->getOperand(bbargIdx);
       state.unionEquivalenceClasses(returnVal, argVal);
     }
@@ -308,23 +300,19 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
-    // Only handle ReturnOp if funcOp is exactly the FuncOp type.
-    if(isa<FuncOp>(funcOp)) {
-      FuncOp funcOpCasted = cast<FuncOp>(funcOp);
-      if (!funcOpCasted.getBody().empty()) {
-        func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
-        if (!returnOp)
-          return funcOp->emitError()
-                 << "cannot bufferize a FuncOp with tensors and "
-                    "without a unique ReturnOp";
-      }
+    if (!funcOp.getFunctionBody().empty() && isa<func::FuncOp>(funcOp)) {
+      Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+      if (!returnOp)
+        return funcOp->emitError()
+               << "cannot bufferize a FuncOp with tensors and "
+                  "without a unique ReturnOp";
     }
 
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
       FunctionOpInterface calledFunction = getCalledFunction(callOp);
-      assert(calledFunction && "could not retrieved called func::FuncOp");
+      assert(calledFunction && "could not retrieve called FunctionOpInterface");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.
       if (!hasTensorSignature(calledFunction))
@@ -362,11 +350,15 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// most generic layout map as function return types. After bufferizing the
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
-static void foldMemRefCasts(func::FuncOp funcOp) {
-  if (funcOp.getBody().empty())
+static void foldMemRefCasts(FunctionOpInterface funcOp) {
+  if (funcOp.getFunctionBody().empty())
+    return;
+
+  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+
+  if (!returnOp)
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
   SmallVector<Type> resultTypes;
 
   for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -379,7 +371,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
   }
 
   auto newFuncType = FunctionType::get(
-      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+      funcOp.getContext(), funcOp.getArgumentTypes(), resultTypes);
   funcOp.setType(newFuncType);
 }
 
@@ -403,39 +395,34 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   // Analyze ops.
   for (FunctionOpInterface funcOp : orderedFuncOps) {
 
-    // The following analysis is specific to the FuncOp type.
-    if(!isa<FuncOp>(funcOp))
-      continue;
-    FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);
-
-    if (!state.getOptions().isOpAllowed(funcOpCasted))
+    if (!state.getOptions().isOpAllowed(funcOp))
       continue;
 
     // Now analyzing function.
-    funcState.startFunctionAnalysis(funcOpCasted);
+    funcState.startFunctionAnalysis(funcOp);
 
     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOpCasted, state, funcState);
+    equivalenceAnalysis(funcOp, state, funcState);
 
     // Analyze funcOp.
-    if (failed(analyzeOp(funcOpCasted, state, statistics)))
+    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();
 
    // Run some extra function analyses.
-    if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
-        failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
+    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
+        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();
 
    // Mark op as fully analyzed.
-    funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
+    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
   }
 
   return success();
 }
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  moduleOp.walk([&](FunctionOpInterface op) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
   });
@@ -475,14 +462,14 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     }
 
     // Change buffer return types to more precise layout maps.
-    if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
-      foldMemRefCasts(cast<func::FuncOp>(funcOp));
+    if (options.inferFunctionResultLayout)
+      foldMemRefCasts(funcOp);
   }
 
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<FunctionOpInterface>(&op))
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
@@ -509,12 +496,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
   // FuncOps whose names are specified in options.noAnalysisFuncFilter will
   // not be analyzed. Ops in these FuncOps will not be analyzed as well.
   OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
-    auto func = dyn_cast<func::FuncOp>(op);
+    auto func = dyn_cast<FunctionOpInterface>(op);
     if (!func)
-      func = op->getParentOfType<func::FuncOp>();
+      func = op->getParentOfType<FunctionOpInterface>();
    if (func)
      return llvm::is_contained(options.noAnalysisFuncFilter,
-                                func.getSymName());
+                                func.getName());
    return false;
  };
  OneShotBufferizationOptions updatedOptions(options);
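
The updated filter in the last hunk shows the general pattern for mapping an arbitrary op to its enclosing function-like op and matching it by symbol name. A standalone sketch of that pattern follows; it is not part of the patch, the names isInExcludedFunction and excludedNames are illustrative, and the include paths are assumptions.

// Illustrative sketch only; assumed include paths.
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionInterfaces.h" // FunctionOpInterface (assumed path)
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

/// Return true if `op` is (or is nested inside) a function-like op whose
/// symbol name appears in `excludedNames`.
static bool isInExcludedFunction(Operation *op,
                                 llvm::ArrayRef<std::string> excludedNames) {
  auto func = dyn_cast<FunctionOpInterface>(op);
  if (!func)
    func = op->getParentOfType<FunctionOpInterface>();
  return func && llvm::is_contained(excludedNames, func.getName());
}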
