Skip to content

Commit 814379d

Browse files
committed
fixup! [mlir][linalg] Add support for masked vectorization of tensor.insert_slice (2/N)
Address comment from Hanhan
1 parent 1abdf4f commit 814379d

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,29 +2743,25 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
27432743
sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
27442744
}
27452745

2746-
// 3.a. TransferReadOp
27472746
SmallVector<Value> readIndices(
27482747
vecType.getRank(),
27492748
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
27502749
Operation *read = rewriter.create<vector::TransferReadOp>(
27512750
sliceOp.getLoc(), vecType, source, readIndices, padValue,
27522751
ArrayRef<bool>{readInBounds});
27532752

2754-
// Mask the xfer_read Op
2755-
if (!inputVectorSizes.empty()) {
2753+
if (maskOp) {
27562754
read = mlir::vector::maskOperation(rewriter, read, maskOp);
27572755
}
27582756

2759-
// 3.b. TransferWriteOp
27602757
auto writeIndices = getValueOrCreateConstantIndexOp(
27612758
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
27622759

27632760
Operation *write = rewriter.create<vector::TransferWriteOp>(
27642761
sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
27652762
ArrayRef<bool>{writeInBounds});
27662763

2767-
// Mask the xfer_write Op
2768-
if (!inputVectorSizes.empty()) {
2764+
if (maskOp) {
27692765
write = mlir::vector::maskOperation(rewriter, write, maskOp);
27702766
}
27712767

0 commit comments

Comments
 (0)