Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mlir/include/mlir/Interfaces/ViewLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,17 @@ def ViewLikeOpInterface : OpInterface<"ViewLikeOpInterface"> {
let methods = [
InterfaceMethod<
"Returns the source buffer from which the view is created.",
"::mlir::Value", "getViewSource">
"::mlir::Value", "getViewSource">,
InterfaceMethod<
/*desc=*/[{ Returns the buffer which the view created. }],
/*retTy=*/"::mlir::Value",
/*methodName=*/"getViewDest",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op->getResult(0);
}]
>
];
}

Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
Operation *op = result.getOwner();

// If this is a view, unwrap to the source.
if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op))
return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
visited, output);
if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) {
if (result == view.getViewDest()) {
return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
visited, output);
}
}
// Check to see if we can reason about the control flow of this op.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result,
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) {

// Check if this is defined to be an alias of another memref.
if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
if (isEscapingMemref(viewOp.getViewSource(), block))
if (memref == viewOp.getViewDest() &&
isEscapingMemref(viewOp.getViewSource(), block))
return true;

// Any op besides allocating ops wouldn't guarantee alias freedom
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
// which otherwise could prevent removal of unnecessary allocs.
Value canonicalSource = source;
while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
canonicalSource.getDefiningOp()))
canonicalSource.getDefiningOp())) {
if (canonicalSource != iface.getViewDest()) {
break;
}
canonicalSource = iface.getViewSource();
}

std::optional<Operation *> maybeCloneDeallocOp =
memref::findDealloc(cloneOp.getOutput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ using namespace mlir::bufferization;
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
if (value != viewLikeOp.getViewDest()) {
break;
}
value = viewLikeOp.getViewSource();
}
return value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
registerDependencies(viewInterface.getViewSource(),
viewInterface->getResult(0));
viewInterface.getViewDest());
return WalkResult::advance();
}

Expand Down Expand Up @@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) {
/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
static Value getViewBase(Value value) {
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) {
if (value != viewLikeOp.getViewDest()) {
break;
}
value = viewLikeOp.getViewSource();
}
return value;
}

Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) {
WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
// Attempt to advance to the source of the underlying view-like operation.
// Examples of view-like operations include GEPOp and AddrSpaceCastOp.
if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>())
return WalkContinuation::advanceTo(viewOp.getViewSource());
if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) {
if (val == viewOp.getViewDest())
return WalkContinuation::advanceTo(viewOp.getViewSource());
}

// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
Expand Down
11 changes: 8 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
Value source = transferRead.getBase();

// Skip view-like Ops and retrive the actual soruce Operation
while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
source = srcOp.getViewSource();
while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) {
if (viewLike.getViewDest() != source) {
break;
}
source = viewLike.getViewSource();
}

llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
Expand All @@ -178,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
if (!processed.insert(user).second)
continue;
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
Value viewDest = viewLike.getViewDest();
users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
PatternRewriter &rewriter) const override {
auto viewLikeOp =
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
if (!viewLikeOp)
if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest())
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
rewriter.modifyOpInPlace(extractOp, [&]() {
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
MemrefValue skipViewLikeOps(MemrefValue source) {
while (auto op = source.getDefiningOp()) {
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
source = cast<MemrefValue>(viewLike.getViewSource());
continue;
if (source == viewLike.getViewDest()) {
source = cast<MemrefValue>(viewLike.getViewSource());
continue;
}
}
return source;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
if (isa<ViewLikeOpInterface>(user)) {
users.append(user->getUsers().begin(), user->getUsers().end());
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
Value viewDest = viewLike.getViewDest();
users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
Expand Down Expand Up @@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
if (isa<ViewLikeOpInterface>(user)) {
users.append(user->getUsers().begin(), user->getUsers().end());
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
Value viewDest = viewLike.getViewDest();
users.append(viewDest.getUsers().begin(), viewDest.getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
Expand Down