Commit f9cc356

[MLIR] vector.constant_mask to support unaligned cases
In the case of unaligned indexing with `vector.constant_mask`, `numFrontPadElems` is always strictly smaller than `numSrcElemsPerDest`. This means that with a non-zero `numFrontPadElems`, the compressed `constant_mask` op will not have any leading zero elements in the innermost dimension, although the relevant values and sizes may change due to the extra step of shifting and aligning elements. This patch enables multi-dimensional support by relying on this property and eliminating the previous single-dimension constraint.
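For concreteness, here is a minimal standalone sketch (plain C++, not the MLIR pattern itself) of the compression arithmetic; the i2-in-i8 packing and the concrete values are assumptions taken from the 1-D test case updated below:

    // Sketch only: reproduces the innermost-dimension compression performed on a
    // vector.constant_mask under unaligned indexing.
    #include <cassert>
    #include <cstdint>

    // Same rounding-up division as llvm::divideCeil for non-negative inputs.
    static int64_t divideCeil(int64_t num, int64_t den) {
      return (num + den - 1) / den;
    }

    int main() {
      const int64_t numSrcElemsPerDest = 4; // four i2 elements per i8 container
      const int64_t numFrontPadElems = 2;   // unaligned offset inside the first i8
      const int64_t origMaskSize = 3;       // vector.constant_mask [3] : vector<5xi1>

      // numFrontPadElems < numSrcElemsPerDest always holds, so the compressed mask
      // never needs leading zeros; only its innermost size grows to cover the pad.
      assert(numFrontPadElems < numSrcElemsPerDest);
      const int64_t newMaskSize =
          divideCeil(numFrontPadElems + origMaskSize, numSrcElemsPerDest);
      assert(newMaskSize == 2); // emitted as vector.constant_mask [2] : vector<2xi1>
      return 0;
    }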
1 parent aa65473 commit f9cc356

2 files changed: +57 -35 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 10 additions & 28 deletions
@@ -128,34 +128,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
             return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                          newMaskOperands);
           })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                SmallVector<int64_t> maskDimSizes(
+                    constantMaskOp.getMaskDimSizes());
+                int64_t &maskIndex = maskDimSizes.back();
+                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+                                             numSrcElemsPerDest);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               maskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.
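The multi-dimensional case follows from the same update applied only to the trailing mask dimension; a hypothetical standalone sketch (plain C++, std::vector standing in for SmallVector, values matching the new multi-dim test below):

    // Sketch of the back()-only update performed by the new ConstantMaskOp case;
    // the packing and mask sizes are assumptions mirroring the multi-dim test.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    static int64_t divideCeil(int64_t num, int64_t den) {
      return (num + den - 1) / den;
    }

    int main() {
      // vector.constant_mask [2, 2] : vector<3x5xi1>, with i2 packed into i8.
      std::vector<int64_t> maskDimSizes = {2, 2};
      const int64_t numSrcElemsPerDest = 4;
      const int64_t numFrontPadElems = 2;

      // Only the innermost dimension is compressed; outer sizes are untouched.
      maskDimSizes.back() =
          divideCeil(numFrontPadElems + maskDimSizes.back(), numSrcElemsPerDest);

      assert(maskDimSizes[0] == 2 && maskDimSizes[1] == 1);
      // i.e. vector.constant_mask [2, 1] : vector<3x2xi1>
      return 0;
    }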

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 47 additions & 7 deletions
@@ -57,7 +57,7 @@ func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector
 // CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
 // CHECK-SAME:   %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -123,6 +123,47 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec
 
 // -----
 
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// compressed mask, used for emulated masked load
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// Create a padded and shifted passthru vector
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+
+// CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
+// CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
+
+// pad and shift the original mask to match the size and location of the loaded value.
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+// CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -252,7 +293,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
 
 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +342,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 
 // -----
 
-func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
   %0 = memref.alloc() : memref<3x5xi2>
   %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
   %c0 = arith.constant 0 : index
@@ -311,24 +352,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
   return %1 : vector<5xi2>
 }
 
-// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
 // CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
 
+// Emulated masked load from alloc:
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
 
 // Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// TODO: fold insert_strided_slice into source if possible.
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
