Skip to content

[MLIR] Make OneShotModuleBufferize use OpInterface #107295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
Expand Down Expand Up @@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FunctionOpInterface>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
Expand Down Expand Up @@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
return llvm::any_of(funcOp.getFunctionType().getInputs(),
static bool hasTensorSignature(FunctionOpInterface funcOp) {
return llvm::any_of(funcOp.getArgumentTypes(),
llvm::IsaPred<TensorType>) ||
llvm::any_of(funcOp.getFunctionType().getResults(),
llvm::any_of(funcOp.getResultTypes(),
llvm::IsaPred<TensorType>);
}

Expand All @@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
if (!funcOp.getBody().empty()) {
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
// Only handle ReturnOp if funcOp is exactly the FuncOp type.
if(isa<FuncOp>(funcOp)) {
FuncOp funcOpCasted = cast<FuncOp>(funcOp);
if (!funcOpCasted.getBody().empty()) {
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
}
}

// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
FunctionOpInterface calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
Expand Down Expand Up @@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<func::FuncOp> orderedFuncOps;
SmallVector<FunctionOpInterface> orderedFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
Expand All @@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return failure();

// Analyze ops.
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
for (FunctionOpInterface funcOp : orderedFuncOps) {

// The following analysis is specific to the FuncOp type.
if(!isa<FuncOp>(funcOp))
continue;
FuncOp funcOpCasted = cast<func::FuncOp>(funcOp);

if (!state.getOptions().isOpAllowed(funcOpCasted))
continue;

// Now analyzing function.
funcState.startFunctionAnalysis(funcOp);
funcState.startFunctionAnalysis(funcOpCasted);

// Gather equivalence info for CallOps.
equivalenceAnalysis(funcOp, state, funcState);
equivalenceAnalysis(funcOpCasted, state, funcState);

// Analyze funcOp.
if (failed(analyzeOp(funcOp, state, statistics)))
if (failed(analyzeOp(funcOpCasted, state, statistics)))
return failure();

// Run some extra function analyses.
if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) ||
failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState)))
return failure();

// Mark op as fully analyzed.
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed;
}

return success();
Expand All @@ -430,7 +449,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());

// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<func::FuncOp> orderedFuncOps;
SmallVector<FunctionOpInterface> orderedFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
Expand All @@ -439,11 +458,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
return failure();

// Bufferize functions.
for (func::FuncOp funcOp : orderedFuncOps) {
for (FunctionOpInterface funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.

if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
Expand All @@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}

// Change buffer return types to more precise layout maps.
if (options.inferFunctionResultLayout)
foldMemRefCasts(funcOp);
if (options.inferFunctionResultLayout && isa<func::FuncOp>(funcOp))
foldMemRefCasts(cast<func::FuncOp>(funcOp));
}

// Bufferize all other ops.
Expand Down
Loading