Skip to content

Commit 7dace4e

Browse files
committed
Updates according to comments
1 parent 98cee5c commit 7dace4e

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
130130
})
131131
.Case<vector::ConstantMaskOp>(
132132
[&](auto constantMaskOp) -> std::optional<Operation *> {
133+
// Take the shape of mask, compress its trailing dimension:
133134
SmallVector<int64_t> maskDimSizes(
134135
constantMaskOp.getMaskDimSizes());
135136
int64_t &maskIndex = maskDimSizes.back();
@@ -586,6 +587,10 @@ struct ConvertVectorMaskedLoad final
586587
LogicalResult
587588
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
588589
ConversionPatternRewriter &rewriter) const override {
590+
// See #115653
591+
if (op.getVectorType().getRank() != 1)
592+
return rewriter.notifyMatchFailure(op,
593+
"only 1-D vectors are supported ATM");
589594

590595
auto loc = op.getLoc();
591596
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,13 @@ func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>)
134134
}
135135

136136
// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
137-
// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
138-
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
139137
// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
140-
// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
138+
// CHECK: vector.extract %[[ORIG_MASK]][1]
141139

142140
// Compressing the mask used for emulated masked load.
143141
// The innermost dimension is compressed to 2 elements from 5.
144-
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
145-
// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
142+
// CHECK: %[[NEW_COMPRESSED_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
143+
// CHECK: vector.extract %[[NEW_COMPRESSED_MASK]][1]
146144

147145
// -----
148146

0 commit comments

Comments
 (0)