2323#include " mlir/IR/BuiltinAttributes.h"
2424#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2525#include " llvm/ADT/TypeSwitch.h"
26+ #include < cstddef>
27+ #include < sys/_types/_int64_t.h>
2628#include < type_traits>
2729
2830namespace mlir {
@@ -67,6 +69,12 @@ class InsertSliceOfTransferWriteOpFolder final
6769
6870 LogicalResult matchAndRewrite (tensor::InsertSliceOp insertSliceOp,
6971 PatternRewriter &rewriter) const override ;
72+
73+ private:
74+ static bool
75+ doesTransferWriteCoverInsertSlice (vector::TransferWriteOp writeOp,
76+ tensor::InsertSliceOp insertSliceOp,
77+ MLIRContext *context);
7078};
7179} // namespace
7280
@@ -136,6 +144,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136144 if (failed (preconditionResult))
137145 return preconditionResult;
138146
147+ if (!doesTransferWriteCoverInsertSlice (writeOp, insertSliceOp,
148+ rewriter.getContext ()))
149+ return rewriter.notifyMatchFailure (
150+ insertSliceOp, " transfer_write does not cover insert_slice" );
151+
139152 SmallVector<Value> indices (writeOp.getIndices ().begin (),
140153 writeOp.getIndices ().end ());
141154 SmallVector<Value> sourceIndices;
@@ -154,6 +167,13 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154167 return success ();
155168}
156169
170+ bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice (
171+ vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
172+ MLIRContext *context) {
173+ // Todo
174+ return true ;
175+ }
176+
157177template <typename OpTy>
158178struct InsertSliceOfInsertSliceFolder : public OpRewritePattern <OpTy> {
159179 using OpRewritePattern<OpTy>::OpRewritePattern;
0 commit comments