1414#include " mlir/Dialect/Affine/IR/AffineOps.h"
1515#include " mlir/Dialect/Arith/IR/Arith.h"
1616#include " mlir/Dialect/MemRef/IR/MemRef.h"
17+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1718#include " mlir/Dialect/Tensor/IR/Tensor.h"
1819#include " mlir/Dialect/Utils/IndexingUtils.h"
1920#include " mlir/Dialect/Vector/IR/VectorOps.h"
@@ -104,10 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
104105 << " \n " );
105106 llvm::SmallVector<Operation *, 8 > blockingAccesses;
106107 Operation *firstOverwriteCandidate = nullptr ;
107- Value source = write.getSource ();
108- // Skip subview ops.
109- while (auto subView = source.getDefiningOp <memref::SubViewOp>())
110- source = subView.getSource ();
108+ Value source =
109+ memref::skipSubViewsAndCasts (cast<MemrefValue>(write.getSource ()));
111110 llvm::SmallVector<Operation *, 32 > users (source.getUsers ().begin (),
112111 source.getUsers ().end ());
113112 llvm::SmallDenseSet<Operation *, 32 > processed;
@@ -116,8 +115,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
116115 // If the user has already been processed skip.
117116 if (!processed.insert (user).second )
118117 continue ;
119- if (auto subView = dyn_cast <memref::SubViewOp>(user)) {
120- users.append (subView ->getUsers ().begin (), subView ->getUsers ().end ());
118+ if (isa <memref::SubViewOp, memref::CastOp >(user)) {
119+ users.append (user ->getUsers ().begin (), user ->getUsers ().end ());
121120 continue ;
122121 }
123122 if (isMemoryEffectFree (user))
@@ -126,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
126125 continue ;
127126 if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128127 // Check candidate that can override the store.
129- if (write.getSource () == nextWrite.getSource () &&
128+ if (memref::isSameViewOrTrivialAlias (
129+ cast<MemrefValue>(nextWrite.getSource ()),
130+ cast<MemrefValue>(write.getSource ())) &&
130131 checkSameValueWAW (nextWrite, write) &&
131132 postDominators.postDominates (nextWrite, write)) {
132133 if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
191192 << " \n " );
192193 SmallVector<Operation *, 8 > blockingWrites;
193194 vector::TransferWriteOp lastwrite = nullptr ;
194- Value source = read.getSource ();
195- // Skip subview ops.
196- while (auto subView = source.getDefiningOp <memref::SubViewOp>())
197- source = subView.getSource ();
195+ Value source =
196+ memref::skipSubViewsAndCasts (cast<MemrefValue>(read.getSource ()));
198197 llvm::SmallVector<Operation *, 32 > users (source.getUsers ().begin (),
199198 source.getUsers ().end ());
200199 llvm::SmallDenseSet<Operation *, 32 > processed;
@@ -203,12 +202,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
203202 // If the user has already been processed skip.
204203 if (!processed.insert (user).second )
205204 continue ;
206- if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
207- users.append (subView->getUsers ().begin (), subView->getUsers ().end ());
208- continue ;
209- }
210- if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
211- users.append (collapsed->getUsers ().begin (), collapsed->getUsers ().end ());
205+ if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
206+ users.append (user->getUsers ().begin (), user->getUsers ().end ());
212207 continue ;
213208 }
214209 if (isMemoryEffectFree (user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
221216 cast<VectorTransferOpInterface>(read.getOperation ()),
222217 /* testDynamicValueUsingBounds=*/ true ))
223218 continue ;
224- if (write.getSource () == read.getSource () &&
219+ if (memref::isSameViewOrTrivialAlias (
220+ cast<MemrefValue>(read.getSource ()),
221+ cast<MemrefValue>(write.getSource ())) &&
225222 dominators.dominates (write, read) && checkSameValueRAW (write, read)) {
226223 if (lastwrite == nullptr || dominators.dominates (lastwrite, write))
227224 lastwrite = write;
0 commit comments