@@ -322,14 +322,19 @@ struct TransferWriteNonPermutationLowering
322322// / %v = vector.transfer_read ...
323323// / permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
324324// / vector.broadcast %v
325- struct TransferOpReduceRank : public OpRewritePattern <vector::TransferReadOp> {
326- using OpRewritePattern::OpRewritePattern;
325+ struct TransferOpReduceRank
326+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
327+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
327328
328- LogicalResult matchAndRewrite (vector::TransferReadOp op,
329- PatternRewriter &rewriter) const override {
329+ FailureOr<mlir::Value>
330+ matchAndRewriteMaskableOp (vector::TransferReadOp op,
331+ MaskingOpInterface maskOp,
332+ PatternRewriter &rewriter) const override {
330333 // TODO: support 0-d corner case.
331334 if (op.getTransferRank () == 0 )
332335 return rewriter.notifyMatchFailure (op, " 0-d corner case not supported" );
336+ if (maskOp)
337+ return rewriter.notifyMatchFailure (op, " Masked case not supported" );
333338
334339 AffineMap map = op.getPermutationMap ();
335340 unsigned numLeadingBroadcast = 0 ;
@@ -369,9 +374,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
369374 op.getLoc (), originalVecType.getElementType (), op.getSource (),
370375 op.getIndices ());
371376 }
372- rewriter. replaceOpWithNewOp <vector::BroadcastOp>(op, originalVecType,
373- newRead);
374- return success ();
377+ return rewriter
378+ . create <vector::BroadcastOp>(op. getLoc (), originalVecType, newRead)
379+ . getVector ();
375380 }
376381
377382 SmallVector<int64_t > newShape (
@@ -393,9 +398,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
393398 op.getLoc (), newReadType, op.getSource (), op.getIndices (),
394399 AffineMapAttr::get (newMap), op.getPadding (), op.getMask (),
395400 newInBoundsAttr);
396- rewriter. replaceOpWithNewOp <vector::BroadcastOp>(op, originalVecType,
397- newRead);
398- return success ();
401+ return rewriter
402+ . create <vector::BroadcastOp>(op. getLoc (), originalVecType, newRead)
403+ . getVector ();
399404 }
400405};
401406
0 commit comments