1818#include " mlir/Dialect/Tensor/Transforms/Transforms.h"
1919#include " mlir/Dialect/Utils/IndexingUtils.h"
2020#include " mlir/Dialect/Vector/IR/VectorOps.h"
21+ #include " mlir/Dialect/Vector/Utils/VectorUtils.h"
2122#include " mlir/IR/AffineMap.h"
2223#include " mlir/IR/BuiltinAttributes.h"
2324#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,12 +49,14 @@ static Value getTensorOperand(tensor::InsertSliceOp op) {
4849namespace {
4950// / Merge extract_slice operation with load/transferRead operation.
5051class TransferReadOfExtractSliceOpFolder final
51- : public OpRewritePattern <vector::TransferReadOp> {
52+ : public vector::MaskableOpRewritePattern <vector::TransferReadOp> {
5253public:
53- using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern ;
54+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
5455
55- LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
56- PatternRewriter &rewriter) const override ;
56+ FailureOr<mlir::Value>
57+ matchAndRewriteMaskableOp (vector::TransferReadOp readOp,
58+ vector::MaskingOpInterface maskOp,
59+ PatternRewriter &rewriter) const override ;
5760};
5861
5962// / Merge insert_slice operation with store/transferWriteOp operation.
@@ -84,8 +87,10 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
8487 return success ();
8588}
8689
87- LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite (
88- vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
90+ FailureOr<mlir::Value>
91+ TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp (
92+ vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
93+ PatternRewriter &rewriter) const {
8994 auto extractSliceOp =
9095 getTensorOperand (readOp).getDefiningOp <tensor::ExtractSliceOp>();
9196 if (!extractSliceOp)
@@ -95,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
95100 preconditionsFoldExtractOrInsertWithTransferOp (rewriter, readOp,
96101 extractSliceOp);
97102 if (failed (preconditionResult))
98- return preconditionResult ;
103+ return rewriter. notifyMatchFailure (readOp, " Failed preconditions " ) ;
99104
100105 SmallVector<Value> indices (readOp.getIndices ().begin (),
101106 readOp.getIndices ().end ());
@@ -105,15 +110,17 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
105110 extractSliceOp.getMixedStrides (), extractSliceOp.getDroppedDims (),
106111 indices, sourceIndices);
107112
108- rewriter.replaceOpWithNewOp <vector::TransferReadOp>(
109- readOp, readOp.getVectorType (), extractSliceOp.getSource (), sourceIndices,
113+ Operation *newOp = rewriter.create <vector::TransferReadOp>(
114+ readOp.getLoc (), readOp.getVectorType (), extractSliceOp.getSource (),
115+ sourceIndices,
110116 AffineMapAttr::get (expandDimsToRank (
111117 readOp.getPermutationMap (), extractSliceOp.getSourceType ().getRank (),
112118 extractSliceOp.getDroppedDims ())),
113119 readOp.getPadding (),
114120 /* mask=*/ Value (), readOp.getInBoundsAttr ());
115-
116- return success ();
121+ if (maskOp)
122+ newOp = mlir::vector::maskOperation (rewriter, newOp, maskOp.getMask ());
123+ return newOp->getResults ()[0 ];
117124}
118125
119126LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite (
0 commit comments