
Commit 1669ac4

[MLIR] Refactor mask compression logic when emulating vector.maskedload ops (#116520)
This patch simplifies and extends the logic used to compress masks emitted by `vector.constant_mask`, adding support for extracting 1-D vectors from multi-dimensional vector loads. The mask computation now compresses only the trailing dimension of the mask shape, so it applies to multi-dimensional masks as well and improves the overall handling of emulated masked load operations.
1 parent 1d810ec · commit 1669ac4
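For reference, the compression rule this patch settles on is arithmetic on the trailing mask dimension only: the number of active lanes after packing becomes ceil((numFrontPadElems + trailingMaskSize) / numSrcElemsPerDest), while all leading mask dimensions are left unchanged. Below is a minimal standalone C++ sketch of that rule; the helper name compressTrailingMaskDim and its parameters are illustrative and not part of the patch, and the asserted values mirror the i2 test cases in this commit (4 i2 elements packed per i8 byte).

#include <cassert>
#include <cstdint>

// Illustrative helper (not from the patch): compress the trailing mask
// dimension after packing numSrcElemsPerDest narrow elements into one
// wider destination element, accounting for front-padding elements.
int64_t compressTrailingMaskDim(int64_t trailingMaskSize,
                                int64_t numFrontPadElems,
                                int64_t numSrcElemsPerDest) {
  // Same rounding as llvm::divideCeil(numFrontPadElems + maskIndex,
  //                                   numSrcElemsPerDest) in the patch.
  return (numFrontPadElems + trailingMaskSize + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}

int main() {
  // vector.constant_mask [3] : vector<5xi1>, 2 front-padding i2 elements,
  // 4 i2 elements per i8 -> compressed mask [2] : vector<2xi1>.
  assert(compressTrailingMaskDim(3, 2, 4) == 2);
  // vector.constant_mask [2, 2] : vector<3x5xi1>, same packing ->
  // compressed mask [2, 1] : vector<3x2xi1> (only the trailing dim changes).
  assert(compressTrailingMaskDim(2, 2, 4) == 1);
  return 0;
}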

File tree

2 files changed: +44, -42 lines changed


mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 11 additions & 29 deletions
@@ -128,34 +128,17 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
             return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                          newMaskOperands);
           })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                // Take the shape of mask, compress its trailing dimension:
+                SmallVector<int64_t> maskDimSizes(
+                    constantMaskOp.getMaskDimSizes());
+                int64_t &maskIndex = maskDimSizes.back();
+                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+                                             numSrcElemsPerDest);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               maskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.
@@ -604,7 +587,6 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     // See #115653
     if (op.getVectorType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 33 additions & 13 deletions
@@ -42,22 +42,20 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {

 // -----

-func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
-  %cst = arith.constant dense<0> : vector<3x5xi2>
   %mask = vector.constant_mask [3] : vector<5xi1>
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
     memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-  return %2 : vector<3x5xi2>
+  return %1 : vector<5xi2>
 }
-
 // CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
-// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
+// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -123,6 +121,29 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec

 // -----

+// This test is similar to @vector_constant_mask_maskedload_i2, but the mask is multi-dimensional.
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: vector.extract %[[ORIG_MASK]][1]
+
+// Compressing the mask used for emulated masked load.
+// The innermost dimension is compressed to 2 elements from 5.
+// CHECK: %[[NEW_COMPRESSED_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: vector.extract %[[NEW_COMPRESSED_MASK]][1]
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -252,7 +273,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>

 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +322,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,

 // -----

-func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
   %c0 = arith.constant 0 : index
@@ -311,24 +332,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
   return %1 : vector<5xi2>
 }

-// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
 // CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>

+// Emulated masked load from alloc:
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>

 // Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// TODO: fold insert_strided_slice into source if possible.
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
