Skip to content

Revert "[MLIR] Make OneShotModuleBufferize use OpInterface" #109919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfoVariant.h"
#include "llvm/ADT/SetVector.h"
Expand Down Expand Up @@ -261,9 +260,9 @@ struct BufferizationOptions {
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, func op, bufferization options
using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
TensorType, Attribute memorySpace, FunctionOpInterface,
const BufferizationOptions &)>;
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {

/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;

/// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;

/// A set of all read BlockArguments of FuncOps.
DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
DenseMap<FuncOp, BbArgIndexSet> readBbArgs;

/// A set of all written-to BlockArguments of FuncOps.
DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;

/// Keep track of which FuncOps are fully analyzed or currently being
/// analyzed.
DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;

/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FunctionOpInterface funcOp);
void startFunctionAnalysis(FuncOp funcOp);
};

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -315,7 +314,7 @@ namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
FunctionOpInterface funcOp,
func::FuncOp funcOp,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
Expand Down Expand Up @@ -362,7 +361,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
FunctionOpInterface funcOp,
func::FuncOp funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace mlir {
namespace bufferization {
namespace func_ext {

void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
auto createdAliasingResults =
Expand Down
111 changes: 56 additions & 55 deletions mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
Expand All @@ -88,11 +88,10 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
Operation *returnOp = nullptr;
for (Block &b : funcOp.getFunctionBody()) {
auto candidateOp = b.getTerminator();
if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
func::ReturnOp returnOp;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
Expand Down Expand Up @@ -127,16 +126,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
OneShotAnalysisState &state,
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
if (funcOp.getFunctionBody().empty()) {
if (funcOp.getBody().empty()) {
// No function body available. Conservatively assume that every tensor
// return value may alias with any tensor bbArg.
for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
FunctionType type = funcOp.getFunctionType();
for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
if (!isa<TensorType>(inputIt.value()))
continue;
for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
for (const auto &resultIt : llvm::enumerate(type.getResults())) {
if (!isa<TensorType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
Expand All @@ -148,7 +147,7 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
}

// Support only single return-terminated block in the function.
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");

for (OpOperand &returnVal : returnOp->getOpOperands())
Expand All @@ -169,8 +168,8 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
return success();
}

static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
bool isRead, bool isWritten) {
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
bool isWritten) {
OpBuilder b(funcOp.getContext());
Attribute accessType;
if (isRead && isWritten) {
Expand All @@ -190,12 +189,12 @@ static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
OneShotAnalysisState &state,
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
++idx) {
// Skip non-tensor arguments.
if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
continue;
bool isRead;
bool isWritten;
Expand All @@ -205,7 +204,7 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
StringRef str = accessAttr.getValue();
isRead = str == "read" || str == "read-write";
isWritten = str == "write" || str == "read-write";
} else if (funcOp.getFunctionBody().empty()) {
} else if (funcOp.getBody().empty()) {
// If the function has no body, conservatively assume that all args are
// read + written.
isRead = true;
Expand All @@ -231,32 +230,33 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kBufferLayoutAttrName);
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kWritableAttrName);
}

static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FunctionOpInterface>(
return dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(FunctionOpInterface funcOp,
static void equivalenceAnalysis(func::FuncOp funcOp,
OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
funcOp->walk([&](CallOpInterface callOp) {
FunctionOpInterface calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called FunctionOpInterface");
funcOp->walk([&](func::CallOp callOp) {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");

// No equivalence info available for the called function.
if (!funcState.equivalentFuncArgs.count(calledFunction))
Expand All @@ -267,7 +267,7 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
int64_t bbargIdx = it.second;
if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
continue;
Value returnVal = callOp->getResult(returnIdx);
Value returnVal = callOp.getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
state.unionEquivalenceClasses(returnVal, argVal);
}
Expand All @@ -277,9 +277,11 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(FunctionOpInterface funcOp) {
return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
static bool hasTensorSignature(func::FuncOp funcOp) {
return llvm::any_of(funcOp.getFunctionType().getInputs(),
llvm::IsaPred<TensorType>) ||
llvm::any_of(funcOp.getFunctionType().getResults(),
llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
Expand All @@ -289,16 +291,16 @@ static bool hasTensorSignature(FunctionOpInterface funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &orderedFuncOps,
FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
if (!funcOp.getFunctionBody().empty()) {
Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
if (!funcOp.getBody().empty()) {
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
Expand All @@ -307,10 +309,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,

// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
FunctionOpInterface calledFunction = getCalledFunction(callOp);
assert(calledFunction &&
"could not retrieved called FunctionOpInterface");
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
Expand Down Expand Up @@ -348,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(FunctionOpInterface funcOp) {
if (funcOp.getFunctionBody().empty())
static void foldMemRefCasts(func::FuncOp funcOp) {
if (funcOp.getBody().empty())
return;

Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
SmallVector<Type> resultTypes;

for (OpOperand &operand : returnOp->getOpOperands()) {
Expand All @@ -364,8 +365,8 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) {
}
}

auto newFuncType = FunctionType::get(funcOp.getContext(),
funcOp.getArgumentTypes(), resultTypes);
auto newFuncType = FunctionType::get(
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
funcOp.setType(newFuncType);
}

Expand All @@ -378,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<FunctionOpInterface> orderedFuncOps;
SmallVector<func::FuncOp> orderedFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
Expand All @@ -387,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return failure();

// Analyze ops.
for (FunctionOpInterface funcOp : orderedFuncOps) {
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;

Expand Down Expand Up @@ -415,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,

void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
moduleOp.walk([&](FunctionOpInterface op) {
moduleOp.walk([&](func::FuncOp op) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
});
Expand All @@ -429,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());

// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<FunctionOpInterface> orderedFuncOps;
SmallVector<func::FuncOp> orderedFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
Expand All @@ -438,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
return failure();

// Bufferize functions.
for (FunctionOpInterface funcOp : orderedFuncOps) {
for (func::FuncOp funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.

if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
Expand All @@ -462,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
if (isa<FunctionOpInterface>(&op))
if (isa<func::FuncOp>(&op))
continue;
if (failed(bufferizeOp(&op, options, statistics)))
return failure();
Expand All @@ -489,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
// not be analyzed. Ops in these FuncOps will not be analyzed as well.
OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
auto func = dyn_cast<FunctionOpInterface>(op);
auto func = dyn_cast<func::FuncOp>(op);
if (!func)
func = op->getParentOfType<FunctionOpInterface>();
func = op->getParentOfType<func::FuncOp>();
if (func)
return llvm::is_contained(options.noAnalysisFuncFilter,
func.getName());
func.getSymName());
return false;
};
OneShotBufferizationOptions updatedOptions(options);
Expand Down
Loading
Loading