@@ -90,14 +90,19 @@ namespace {
9090// / Note that an alternative is to transform it to linalg.transpose +
9191// / vector.transfer_read to do the transpose in memory instead.
9292struct TransferReadPermutationLowering
93- : public OpRewritePattern <vector::TransferReadOp> {
94- using OpRewritePattern::OpRewritePattern ;
93+ : public MaskableOpRewritePattern <vector::TransferReadOp> {
94+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
9595
96- LogicalResult matchAndRewrite (vector::TransferReadOp op,
97- PatternRewriter &rewriter) const override {
96+ FailureOr<mlir::Value>
97+ matchAndRewriteMaskableOp (vector::TransferReadOp op,
98+ MaskingOpInterface maskOp,
99+ PatternRewriter &rewriter) const override {
98100 // TODO: support 0-d corner case.
99101 if (op.getTransferRank () == 0 )
100102 return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
103+ // TODO: Support transfer_read inside MaskOp case.
104+ if (maskOp)
105+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
101106
102107 SmallVector<unsigned > permutation;
103108 AffineMap map = op.getPermutationMap ();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
142147
143148 // Transpose result of transfer_read.
144149 SmallVector<int64_t > transposePerm (permutation.begin (), permutation.end ());
145- rewriter. replaceOpWithNewOp <vector::TransposeOp>(op, newRead,
146- transposePerm);
147- return success ();
150+ return rewriter
151+ . create <vector::TransposeOp>(op. getLoc (), newRead, transposePerm)
152+ . getResult ();
148153 }
149154};
150155
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
165170// / %v = vector.transfer_write %tmp ...
166171// / permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167172struct TransferWritePermutationLowering
168- : public OpRewritePattern <vector::TransferWriteOp> {
169- using OpRewritePattern::OpRewritePattern ;
173+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
174+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
170175
171- LogicalResult matchAndRewrite (vector::TransferWriteOp op,
172- PatternRewriter &rewriter) const override {
176+ FailureOr<mlir::Value>
177+ matchAndRewriteMaskableOp (vector::TransferWriteOp op,
178+ MaskingOpInterface maskOp,
179+ PatternRewriter &rewriter) const override {
173180 // TODO: support 0-d corner case.
174181 if (op.getTransferRank () == 0 )
175182 return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
183+ // TODO: Support transfer_write inside MaskOp case.
184+ if (maskOp)
185+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
176186
177187 SmallVector<unsigned > permutation;
178188 AffineMap map = op.getPermutationMap ();
@@ -207,11 +217,14 @@ struct TransferWritePermutationLowering
207217 op.getLoc (), op.getVector (), indices);
208218 auto newMap = AffineMap::getMinorIdentityMap (
209219 map.getNumDims (), map.getNumResults (), rewriter.getContext ());
210- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
211- op, newVec, op.getSource (), op.getIndices (), AffineMapAttr::get (newMap),
212- op.getMask (), newInBoundsAttr);
213-
214- return success ();
220+ auto newWrite = rewriter.create <vector::TransferWriteOp>(
221+ op.getLoc (), newVec, op.getSource (), op.getIndices (),
222+ AffineMapAttr::get (newMap), op.getMask (), newInBoundsAttr);
223+ if (newWrite.hasPureTensorSemantics ())
224+ return newWrite.getResult ();
225+ // In the memref case there's no return value. Use empty value to signal
226+ // success.
227+ return Value ();
215228 }
216229};
217230
@@ -231,14 +244,19 @@ struct TransferWritePermutationLowering
231244// / vector<1x8x16xf32>
232245// / ```
233246struct TransferWriteNonPermutationLowering
234- : public OpRewritePattern <vector::TransferWriteOp> {
235- using OpRewritePattern::OpRewritePattern ;
247+ : public MaskableOpRewritePattern <vector::TransferWriteOp> {
248+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
236249
237- LogicalResult matchAndRewrite (vector::TransferWriteOp op,
238- PatternRewriter &rewriter) const override {
250+ FailureOr<mlir::Value>
251+ matchAndRewriteMaskableOp (vector::TransferWriteOp op,
252+ MaskingOpInterface maskOp,
253+ PatternRewriter &rewriter) const override {
239254 // TODO: support 0-d corner case.
240255 if (op.getTransferRank () == 0 )
241256 return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
257+ // TODO: Support transfer_write inside MaskOp case.
258+ if (maskOp)
259+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
242260
243261 SmallVector<unsigned > permutation;
244262 AffineMap map = op.getPermutationMap ();
@@ -285,10 +303,14 @@ struct TransferWriteNonPermutationLowering
285303 newInBoundsValues.push_back (op.isDimInBounds (i));
286304 }
287305 ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr (newInBoundsValues);
288- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
289- op, newVec, op.getSource (), op.getIndices (), AffineMapAttr::get (newMap),
290- newMask, newInBoundsAttr);
291- return success ();
306+ auto newWrite = rewriter.create <vector::TransferWriteOp>(
307+ op.getLoc (), newVec, op.getSource (), op.getIndices (),
308+ AffineMapAttr::get (newMap), newMask, newInBoundsAttr);
309+ if (newWrite.hasPureTensorSemantics ())
310+ return newWrite.getResult ();
311+ // In the memref case there's no return value. Use empty value to signal
312+ // success.
313+ return Value ();
292314 }
293315};
294316
0 commit comments