Commit 98cee5c
[MLIR] Refactor mask compression logic when emulating vector.maskedload ops
This patch simplifies and extends the logic for compressing masks emitted by `vector.constant_mask`, adding support for extracting 1-D vectors from multi-dimensional vector loads. The mask computation is streamlined so that it also applies to multi-dimensional mask generation, improving the overall handling of emulated masked load operations.
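
For reference, here is a minimal, self-contained C++ sketch of the compression arithmetic applied to the trailing mask dimension, using the variable names from the diff below (numFrontPadElems, numSrcElemsPerDest, maskDimSizes). The helper name compressConstantMaskDims and the use of std::vector are illustrative assumptions, not the upstream code.

// Sketch of compressing the trailing dimension of a vector.constant_mask:
// shift it by the front padding and round up to whole destination elements
// (llvm::divideCeil in the actual patch).
#include <cstdint>
#include <vector>

std::vector<int64_t>
compressConstantMaskDims(std::vector<int64_t> maskDimSizes,
                         int64_t numFrontPadElems,
                         int64_t numSrcElemsPerDest) {
  int64_t &maskIndex = maskDimSizes.back();
  maskIndex = (numFrontPadElems + maskIndex + numSrcElemsPerDest - 1) /
              numSrcElemsPerDest;
  return maskDimSizes;
}

// From the i2 -> i8 tests below (4 i2 elements per i8 byte):
//   {3}    with 2 front-pad elements -> {2}    ([3] : vector<5xi1>      -> [2]    : vector<2xi1>)
//   {2, 2} with 2 front-pad elements -> {2, 1} ([2, 2] : vector<3x5xi1> -> [2, 1] : vector<3x2xi1>)

Because only the trailing dimension is rewritten, the same computation serves 1-D masks and the innermost dimension of multi-dimensional masks; the patch also drops the rank-1 bailout (// See #115653) from ConvertVectorMaskedLoad, as shown in the second hunk below.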
1 parent 2ed8c5d commit 98cee5c

File tree

2 files changed: +45 -46 lines

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 10 additions & 33 deletions
@@ -128,34 +128,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
             return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                          newMaskOperands);
           })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                SmallVector<int64_t> maskDimSizes(
+                    constantMaskOp.getMaskDimSizes());
+                int64_t &maskIndex = maskDimSizes.back();
+                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+                                             numSrcElemsPerDest);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               maskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.
@@ -605,11 +587,6 @@ struct ConvertVectorMaskedLoad final
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {

-    // See #115653
-    if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
-
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 35 additions & 13 deletions
@@ -42,22 +42,20 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {

 // -----

-func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
-  %cst = arith.constant dense<0> : vector<3x5xi2>
   %mask = vector.constant_mask [3] : vector<5xi1>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
     memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-  return %2 : vector<3x5xi2>
+  return %1 : vector<5xi2>
 }
-
 // CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
-// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
+// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
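
As a hedged sanity check of the CHECK lines above: the load at [2, 0] on memref<3x5xi2> starts 10 i2 elements in, i.e. 2 elements past a byte boundary, and with 4 i2 elements per emulated i8 the original [3] mask compresses to [2]. The divideCeil helper below is a local stand-in for llvm::divideCeil, introduced only for this illustration.

// Arithmetic behind: vector.constant_mask [3] : vector<5xi1>
//                 -> vector.constant_mask [2] : vector<2xi1>
#include <cstdint>

constexpr int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// numFrontPadElems = (2 * 5 + 0) % 4 = 2, numSrcElemsPerDest = 4.
static_assert(divideCeil(2 + 3, 4) == 2,
              "[3] : vector<5xi1> compresses to [2] : vector<2xi1>");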
@@ -123,6 +121,31 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec

 // -----

+// This test is similar to @vector_constant_mask_maskedload_i2, but the mask is multi-dimensional.
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// Compressing the mask used for emulated masked load.
+// The innermost dimension is compressed to 2 elements from 5.
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
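
The same hedged arithmetic explains the [2, 1] mask in the multi-dimensional test just added above: the load at [2, 0, 0] on memref<4x3x5xi2> starts 30 i2 elements in (2 past a byte boundary), only the innermost mask dimension [2] is compressed, and the outer dimension is kept. Again, divideCeil is a local stand-in for llvm::divideCeil.

// Arithmetic behind: vector.constant_mask [2, 2] : vector<3x5xi1>
//                 -> vector.constant_mask [2, 1] : vector<3x2xi1>
#include <cstdint>

constexpr int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// numFrontPadElems = (2 * 3 * 5) % 4 = 2, numSrcElemsPerDest = 4;
// the memref flattens to 60 i2 elements = memref<15xi8>.
static_assert(divideCeil(2 + 2, 4) == 1,
              "[2, 2] : vector<3x5xi1> compresses to [2, 1] : vector<3x2xi1>");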
@@ -252,7 +275,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>

 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +324,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,

 // -----

-func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
   %c0 = arith.constant 0 : index
@@ -311,24 +334,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
   return %1 : vector<5xi2>
 }

-// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
 // CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>

+// Emulated masked load from alloc:
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>

 // Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// TODO: fold insert_strided_slice into source if possible.
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
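
For completeness, the compressed dense<true> : vector<2xi1> mask in this last test comes from the arith.constant path of getCompressedMaskOp, which this patch leaves functionally unchanged. The sketch below reproduces that result under the assumption that a byte is selected whenever any of its four i2 lanes is selected (group-wise OR); the assumption and the variable names are mine, not the pass's.

// Why the compressed mask is dense<true> : vector<2xi1> for the mask
// [false, true, true, true, false] padded to 8 lanes at offset 1.
#include <array>
#include <cstddef>

int main() {
  std::array<bool, 8> paddedMask = {false, false, true, true,
                                    true,  false, false, false};
  std::array<bool, 2> compressed = {false, false};
  for (std::size_t i = 0; i < paddedMask.size(); ++i)
    compressed[i / 4] = compressed[i / 4] || paddedMask[i];
  // compressed == {true, true}, i.e. dense<true> : vector<2xi1>.
  return (compressed[0] && compressed[1]) ? 0 : 1;
}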
