diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td index d1401c238381e..ed213bfdae337 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -26,7 +26,17 @@ def ViewLikeOpInterface : OpInterface<"ViewLikeOpInterface"> { let methods = [ InterfaceMethod< "Returns the source buffer from which the view is created.", - "::mlir::Value", "getViewSource"> + "::mlir::Value", "getViewSource">, + InterfaceMethod< + /*desc=*/[{ Returns the buffer which the view created. }], + /*retTy=*/"::mlir::Value", + /*methodName=*/"getViewDest", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op->getResult(0); + }] + > ]; } diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 6cece4630a0e5..8062b474539fd 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -127,9 +127,12 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, Operation *op = result.getOwner(); // If this is a view, unwrap to the source. - if (ViewLikeOpInterface view = dyn_cast(op)) - return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, - visited, output); + if (ViewLikeOpInterface view = dyn_cast(op)) { + if (result == view.getViewDest()) { + return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, + visited, output); + } + } // Check to see if we can reason about the control flow of this op. if (auto branch = dyn_cast(op)) { return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result, diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 6c9adff7e9106..3ab5de2372278 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) { // Check if this is defined to be an alias of another memref. if (auto viewOp = dyn_cast(defOp)) - if (isEscapingMemref(viewOp.getViewSource(), block)) + if (memref == viewOp.getViewDest() && + isEscapingMemref(viewOp.getViewSource(), block)) return true; // Any op besides allocating ops wouldn't guarantee alias freedom diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index f1f12f4bca70e..56ff2121e4620 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern { // which otherwise could prevent removal of unnecessary allocs. Value canonicalSource = source; while (auto iface = dyn_cast_or_null( - canonicalSource.getDefiningOp())) + canonicalSource.getDefiningOp())) { + if (canonicalSource != iface.getViewDest()) { + break; + } canonicalSource = iface.getViewSource(); + } std::optional maybeCloneDeallocOp = memref::findDealloc(cloneOp.getOutput()); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 891652670dd5b..a465c957d063e 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -37,8 +37,12 @@ using namespace mlir::bufferization; /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. static Value getViewBase(Value value) { - while (auto viewLikeOp = value.getDefiningOp()) + while (auto viewLikeOp = value.getDefiningOp()) { + if (value != viewLikeOp.getViewDest()) { + break; + } value = viewLikeOp.getViewSource(); + } return value; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp index 8f983ab1eae36..0b2e080e52b75 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) { // Add additional dependencies created by view changes to the alias list. if (auto viewInterface = dyn_cast(op)) { registerDependencies(viewInterface.getViewSource(), - viewInterface->getResult(0)); + viewInterface.getViewDest()); return WalkResult::advance(); } @@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) { /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. static Value getViewBase(Value value) { - while (auto viewLikeOp = value.getDefiningOp()) + while (auto viewLikeOp = value.getDefiningOp()) { + if (value != viewLikeOp.getViewDest()) { + break; + } value = viewLikeOp.getViewSource(); + } return value; } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 18f85b6f15bd4..4ea2ac957fa1a 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -235,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) { WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) { // Attempt to advance to the source of the underlying view-like operation. // Examples of view-like operations include GEPOp and AddrSpaceCastOp. - if (auto viewOp = val.getDefiningOp()) - return WalkContinuation::advanceTo(viewOp.getViewSource()); + if (auto viewOp = val.getDefiningOp()) { + if (val == viewOp.getViewDest()) + return WalkContinuation::advanceTo(viewOp.getViewSource()); + } // Attempt to advance to control flow predecessors. std::optional> controlFlowPredecessors = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 922b7d69e46c1..36434cf2d2ae2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -166,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, Value source = transferRead.getBase(); // Skip view-like Ops and retrive the actual soruce Operation - while (auto srcOp = source.getDefiningOp()) - source = srcOp.getViewSource(); + while (auto viewLike = source.getDefiningOp()) { + if (viewLike.getViewDest() != source) { + break; + } + source = viewLike.getViewSource(); + } llvm::SmallVector users(source.getUsers().begin(), source.getUsers().end()); @@ -178,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, if (!processed.insert(user).second) continue; if (auto viewLike = dyn_cast(user)) { - users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa(user)) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 9771bd2aaa143..d35566a9c0d29 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp(); - if (!viewLikeOp) + if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 5af46a48f124f..3de9c3898c713 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) { MemrefValue skipViewLikeOps(MemrefValue source) { while (auto op = source.getDefiningOp()) { if (auto viewLike = dyn_cast(op)) { - source = cast(viewLike.getViewSource()); - continue; + if (source == viewLike.getViewDest()) { + source = cast(viewLike.getViewSource()); + continue; + } } return source; } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c707f38d9081c..369857fac7a06 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { // If the user has already been processed skip. if (!processed.insert(user).second) continue; - if (isa(user)) { - users.append(user->getUsers().begin(), user->getUsers().end()); + if (auto viewLike = dyn_cast(user)) { + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user)) @@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { // If the user has already been processed skip. if (!processed.insert(user).second) continue; - if (isa(user)) { - users.append(user->getUsers().begin(), user->getUsers().end()); + if (auto viewLike = dyn_cast(user)) { + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa(user))