[MLIR] Refactor mask compression logic when emulating vector.maskedload ops
#116520
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes:

In the case of unaligned indexing with a constant mask, the compressed mask depends only on the innermost mask dimension, the number of front-padding elements, and the number of source elements per destination element; the leading mask dimensions carry over unchanged. This patch enables multi-dimensional support by simply observing the abovementioned property and eliminating the constraints.

Full diff: https://github.com/llvm/llvm-project/pull/116520.diff

2 Files Affected:
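Before the diff, here is an illustrative, standalone C++ sketch of the compressed-mask arithmetic that the refactored ConstantMaskOp case performs; it is not code from the patch. The helper names (divideCeil, compressMaskDims) and the sample numbers are assumptions, chosen to mirror the updated tests (i2 elements packed four per i8 byte, two elements of front padding from the unaligned start).

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors llvm::divideCeil for non-negative operands.
static int64_t divideCeil(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Compress only the innermost mask dimension; the refactored code keeps the
// leading mask dimensions unchanged.
static std::vector<int64_t> compressMaskDims(std::vector<int64_t> maskDimSizes,
                                             int64_t numFrontPadElems,
                                             int64_t numSrcElemsPerDest) {
  int64_t origIndex = maskDimSizes.back();
  maskDimSizes.back() =
      divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
  return maskDimSizes;
}

int main() {
  // 1-D case from @vector_cst_maskedload_i2: [3] compresses to [2],
  // i.e. vector.constant_mask [2] : vector<2xi1>.
  for (int64_t d : compressMaskDims({3}, /*numFrontPadElems=*/2,
                                    /*numSrcElemsPerDest=*/4))
    std::cout << d << ' ';
  std::cout << '\n';

  // Multi-dim case from the new test: [2, 2] compresses to [2, 1],
  // i.e. vector.constant_mask [2, 1] : vector<3x2xi1>.
  for (int64_t d : compressMaskDims({2, 2}, /*numFrontPadElems=*/2,
                                    /*numSrcElemsPerDest=*/4))
    std::cout << d << ' ';
  std::cout << '\n';
}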
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc8bab325184b8..26a5f566f34948 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -125,34 +125,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
})
- .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
- -> std::optional<Operation *> {
- ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
- size_t numMaskOperands = maskDimSizes.size();
- int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
- int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
- numSrcElemsPerDest);
-
- // TODO: we only want the mask between [startIndex, maskIndex]
- // to be true, the rest are false.
- if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
- return std::nullopt;
-
- SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
-
- if (numFrontPadElems == 0)
- return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- newMaskDimSizes);
-
- SmallVector<bool> newMaskValues;
- for (int64_t i = 0; i < numDestElems; ++i)
- newMaskValues.push_back(i >= startIndex && i < maskIndex);
- auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
- return rewriter.create<arith::ConstantOp>(loc, newMaskType,
- newMask);
- })
+ .Case<vector::ConstantMaskOp>(
+ [&](auto constantMaskOp) -> std::optional<Operation *> {
+ ArrayRef<int64_t> maskDimSizes =
+ constantMaskOp.getMaskDimSizes();
+ size_t numMaskOperands = maskDimSizes.size();
+ int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+ int64_t maskIndex = llvm::divideCeil(
+ numFrontPadElems + origIndex, numSrcElemsPerDest);
+ SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+ return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+ newMaskDimSizes);
+ })
.Case<arith::ConstantOp>([&](auto constantOp)
-> std::optional<Operation *> {
// TODO: Support multiple dimensions.
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b1a0d4f924f3cf..73ce7ac9be2437 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -57,7 +57,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
// CHECK-LABEL: func @vector_cst_maskedload_i2(
// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -74,6 +74,48 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
// -----
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+ %0 = memref.alloc() : memref<4x3x5xi2>
+ %cst = arith.constant dense<0> : vector<3x5xi2>
+ %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+ %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+ memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+ return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// compressed mask, used for emulated masked load
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// Create a padded and shifted passthru vector
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+
+// CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
+// CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
+
+// pad and shift the original mask to match the size and location of the loaded value.
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [2], strides = [1]}
+// CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
+
+// -----
+
func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
%0 = memref.alloc() : memref<3x3xi2>
%cst = arith.constant dense<0> : vector<3x3xi2>
@@ -203,7 +245,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -268,18 +310,17 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
+
+// Emulated masked load from alloc:
// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
-// Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// Select from emulated loaded vector and passthru vector: (TODO: fold this part if possible)
// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
There's quite a bit of detail here, but the goal and justification are unclear to me. Is the intent something along these lines?
Could you clarify this statement? Phrases like "multi-dimensional support" are quite broad. From what I understand, this patch specifically adds support for extracting 1-D masks from N-D masks (where N > 1). True "multi-dimensional support" would involve directly handling full N-D masks (e.g., 2-D masks) rather than just reducing their rank.
banach-space left a comment:
Thanks! Some minor suggestions inline.
hanhanW left a comment:
Thanks for the patch! IIUC, it works because we scope the emulation to vector.maskedload ops with 1-D result vectors. There are two masks in the emulation. One is used by the vector.maskedload, which ensures that we load enough bits from memory. The other is used for the final result, selecting the bits from the loaded data and the passthru. In this patch, we observe that we can simplify the logic for the first mask, which looks good to me. Please correct me if I misread something. :)
> Could you clarify this statement? Phrases like "multi-dimensional support" are quite broad. From what I understand, this patch specifically adds support for extracting 1-D masks from N-D masks (where N > 1). True "multi-dimensional support" would involve directly handling full N-D masks (e.g., 2-D masks) rather than just reducing their rank.
+1 on what @banach-space said, I had the same confusion. I checked the code and found that the maskedload emulation is scoped to 1D cases:
llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp, lines 605 to 608 (at cab7328):
// See #115653
if (op.getVectorType().getRank() != 1)
  return rewriter.notifyMatchFailure(op,
                                     "only 1-D vectors are supported ATM");
@banach-space @hanhanW Thanks for the review! I have updated the PR description to represent the change more faithfully, along with other changes according to the comments!
hanhanW left a comment:
I like the refactoring idea, and it looks good to me. I only have a question about removing the check.
banach-space left a comment:
A few nits that you can ignore. Agreed with @hanhanW though that the check should be preserved.
hanhanW left a comment:
thanks, lgtm!
dcaballe left a comment:
Cool!
banach-space left a comment:
LGTM, thank you!
This patch simplifies and extends the logic used when compressing masks emitted by `vector.constant_mask` to support extracting 1-D vectors from multi-dimensional vector loads. It streamlines mask computation, making it applicable for multi-dimensional mask generation, improving the overall handling of masked load operations.