Commit 5862798
[MLIR] vector.constant_mask to support unaligned cases
In the case of unaligned indexing with `vector.constant_mask`, observe that `numFrontPadElems` is always strictly smaller than `numSrcElemsPerDest`. Consequently, with a non-zero `numFrontPadElems`, the compressed `constant_mask` op never acquires leading zero elements in the innermost dimension, although the relevant mask values and sizes may change due to the extra step of shifting and aligning elements. This patch enables multi-dimensional support by exploiting this property and eliminating the constraint.
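To make the compressed-mask arithmetic concrete, here is a minimal standalone C++ sketch (the hand-rolled divideCeil stands in for llvm::divideCeil, and the constants are taken from the first i2 test below, where four i2 elements pack into one i8 byte):

#include <cassert>
#include <cstdint>

// Stand-in for llvm::divideCeil.
static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t numSrcElemsPerDest = 4; // four i2 elements per i8 byte
  const int64_t numFrontPadElems = 2;   // unaligned: the load starts mid-byte
  const int64_t origIndex = 3;          // innermost size of constant_mask [3]

  // The key invariant: front padding never spans a full destination element,
  // so the compressed mask gains no leading zeros in the innermost dimension.
  assert(numFrontPadElems < numSrcElemsPerDest);

  // New innermost mask size: ceil((2 + 3) / 4) = 2, i.e.
  // vector.constant_mask [3] : vector<5xi1> compresses to
  // vector.constant_mask [2] : vector<2xi1>.
  int64_t maskIndex =
      divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
  assert(maskIndex == 2);
  return 0;
}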
1 parent ec353b7

2 files changed: +60 −34 lines
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 13 additions & 28 deletions
@@ -125,34 +125,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
             return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                          newMaskOperands);
           })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                ArrayRef<int64_t> maskDimSizes =
+                    constantMaskOp.getMaskDimSizes();
+                size_t numMaskOperands = maskDimSizes.size();
+                int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+                int64_t maskIndex = llvm::divideCeil(
+                    numFrontPadElems + origIndex, numSrcElemsPerDest);
+                SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+                newMaskDimSizes.push_back(maskIndex);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               newMaskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 47 additions & 6 deletions
@@ -57,7 +57,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 // CHECK-LABEL: func @vector_cst_maskedload_i2(
 // CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>

@@ -74,6 +74,48 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 // -----
 
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %cst = arith.constant dense<0> : vector<3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// Compressed mask, used for the emulated masked load:
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// Create a padded and shifted passthru vector:
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+
+// CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
+// CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
+
+// Pad and shift the original mask to match the size and location of the loaded value:
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+// CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
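The multi-dimensional CHECKs in the test above follow from the same arithmetic: outer mask dimensions carry over unchanged and only the innermost size is compressed. A minimal sketch, with std::vector standing in for SmallVector and constants taken from the test:

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for llvm::divideCeil.
static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // vector.constant_mask [2, 2] : vector<3x5xi1>, loaded with 2 elements of
  // front padding and 4 i2 elements packed per i8 byte.
  const std::vector<int64_t> maskDimSizes = {2, 2};
  const int64_t numFrontPadElems = 2;
  const int64_t numSrcElemsPerDest = 4;

  // Outer dimensions carry over (maskDimSizes.drop_back() in the patch) ...
  std::vector<int64_t> newMaskDimSizes(maskDimSizes.begin(),
                                       maskDimSizes.end() - 1);
  // ... and only the innermost size is compressed: ceil((2 + 2) / 4) = 1.
  newMaskDimSizes.push_back(
      divideCeil(numFrontPadElems + maskDimSizes.back(), numSrcElemsPerDest));

  // Result {2, 1}: vector.constant_mask [2, 1] : vector<3x2xi1>, as CHECKed.
  assert(newMaskDimSizes == (std::vector<int64_t>{2, 1}));
  return 0;
}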
@@ -203,7 +245,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
 
 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -268,18 +310,17 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
 
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
+
+// Emulated masked load from alloc:
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
 
-// Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// Select from emulated loaded vector and passthru vector: (TODO: fold this part if possible)
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
