Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class FuncOp;

namespace bufferization {
/// Helper function that returns all func.return ops in the given function.
SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
SmallVector<Operation *> getReturnOps(func::FuncOp funcOp);

namespace func_ext {
/// The state of analysis of a FuncOp.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@

namespace mlir {
/// Return all func.return ops in the given function.
SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
SmallVector<func::ReturnOp> result;
for (Block &b : funcOp.getBody())
if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
result.push_back(returnOp);
SmallVector<Operation *> bufferization::getReturnOps(func::FuncOp funcOp) {
SmallVector<Operation *> result;
for (Block &b : funcOp.getBody()) {
Operation *terminator = b.getTerminator();
if (terminator->hasTrait<OpTrait::ReturnLike>())
result.push_back(b.getTerminator());
}
return result;
}

Expand Down Expand Up @@ -439,7 +441,7 @@ struct FuncOpInterface
return failure();

// 2. Bufferize the operands of the all return op.
for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
for (Operation *returnOp : getReturnOps(funcOp)) {
assert(returnOp->getNumOperands() == retTypes.size() &&
"incorrect number of return values");
SmallVector<Value> returnValues;
Expand All @@ -457,11 +459,13 @@ struct FuncOpInterface
// Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
returnOp.getLoc(), bufferizedType, returnVal);
returnOp->getLoc(), bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}

returnOp.getOperandsMutable().assign(returnValues);
for (auto [i, operand] : enumerate(returnValues)) {
returnOp->setOperand(i, operand);
}
}

// 3. Set the new function type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
}

// Find all func.return ops.
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
SmallVector<Operation *> returnOps = getReturnOps(funcOp);
assert(!returnOps.empty() && "expected at least one ReturnOp");

// Build alias sets. Merge all aliases from all func.return ops.
Expand All @@ -142,7 +142,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
int64_t bbArgIdx = bbArg.getArgNumber();
// Store aliases in a set, so that we don't add the same alias twice.
SetVector<int64_t> aliases;
for (func::ReturnOp returnOp : returnOps) {
for (Operation *returnOp : returnOps) {
for (OpOperand &returnVal : returnOp->getOpOperands()) {
if (isa<RankedTensorType>(returnVal.get().getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
Expand Down Expand Up @@ -192,7 +192,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// argument for the i-th operand. In contrast to aliasing information,
// which is just "merged", equivalence information must match across all
// func.return ops.
for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
for (Operation *returnOp : ArrayRef(returnOps).drop_front()) {
std::optional<int64_t> maybeEquiv =
findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
if (maybeEquiv != bbArgIdx) {
Expand Down Expand Up @@ -398,7 +398,7 @@ static Value unpackCast(Value v) {
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
static SmallVector<Type> getReturnTypes(SmallVector<Operation *> returnOps) {
assert(!returnOps.empty() && "expected at least one ReturnOp");
int numOperands = returnOps.front()->getNumOperands();

Expand Down Expand Up @@ -434,11 +434,11 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
return;

// Compute the common result types of all return ops.
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
SmallVector<Operation *> returnOps = getReturnOps(funcOp);
SmallVector<Type> resultTypes = getReturnTypes(returnOps);

// Remove direct casts.
for (func::ReturnOp returnOp : returnOps) {
for (Operation *returnOp : returnOps) {
for (OpOperand &operand : returnOp->getOpOperands()) {
// Bail if no common result type was found.
if (resultTypes[operand.getOperandNumber()]) {
Expand Down