2121#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
2222#include " mlir/IR/AffineMap.h"
2323#include " mlir/IR/BuiltinAttributes.h"
24+ #include " mlir/Interfaces/ValueBoundsOpInterface.h"
2425#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2526#include " llvm/ADT/TypeSwitch.h"
2627#include < type_traits>
@@ -67,6 +68,12 @@ class InsertSliceOfTransferWriteOpFolder final
6768
6869 LogicalResult matchAndRewrite (tensor::InsertSliceOp insertSliceOp,
6970 PatternRewriter &rewriter) const override ;
71+
72+ private:
73+ static bool
74+ doesTransferWriteCoverInsertSlice (vector::TransferWriteOp writeOp,
75+ tensor::InsertSliceOp insertSliceOp,
76+ MLIRContext *context);
7077};
7178} // namespace
7279
@@ -136,6 +143,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136143 if (failed (preconditionResult))
137144 return preconditionResult;
138145
146+ if (!doesTransferWriteCoverInsertSlice (writeOp, insertSliceOp,
147+ rewriter.getContext ()))
148+ return rewriter.notifyMatchFailure (
149+ insertSliceOp, " transfer_write does not cover insert_slice" );
150+
139151 SmallVector<Value> indices (writeOp.getIndices ().begin (),
140152 writeOp.getIndices ().end ());
141153 SmallVector<Value> sourceIndices;
@@ -154,6 +166,25 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154166 return success ();
155167}
156168
169+ bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice (
170+ vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
171+ MLIRContext *context) {
172+ auto destType = cast<ShapedType>(writeOp.getOperand (0 ).getType ());
173+ auto insertSliceType = insertSliceOp.getSourceType ();
174+
175+ if (destType.hasStaticShape () && insertSliceType.hasStaticShape ()) {
176+ for (int64_t d = 0 , e = insertSliceType.getRank (); d < e; ++d) {
177+ if (destType.getDimSize (d) != insertSliceType.getDimSize (d))
178+ return false ;
179+ }
180+ return true ;
181+ }
182+
183+ // Todo: ValueBoundsConstraintSet for dynamic shapes.
184+
185+ return true ;
186+ }
187+
157188template <typename OpTy>
158189struct InsertSliceOfInsertSliceFolder : public OpRewritePattern <OpTy> {
159190 using OpRewritePattern<OpTy>::OpRewritePattern;
0 commit comments