
Conversation


@lialan lialan commented Nov 17, 2024

This patch simplifies and extends the logic used when compressing masks emitted by vector.constant_mask to support extracting 1-D vectors from multi-dimensional vector loads. It streamlines the mask computation, making it applicable to multi-dimensional mask generation and improving the overall handling of masked load operations.


llvmbot commented Nov 17, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes

In the case of unaligned indexing and vector.constant_mask, notice that numFrontPadElems is always strictly smaller than numSrcElemsPerDest, which means that with a non-zero numFrontPadElems, the compressed constant_mask op will not have any leading zero elements in the innermost dimension, although the relevant values and sizes might change due to the extra step of shifting and aligning elements.

This patch enables multi-dimensional support by simply observing the abovementioned property and eliminating the constraints.
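
To make the arithmetic concrete, here is a small worked example based on the @vector_cst_maskedload_i2 test updated in this diff (i2 elements are emulated via i8, so numSrcElemsPerDest = 4, and the passthru is shifted by numFrontPadElems = 2; SSA names are illustrative):

%mask = vector.constant_mask [3] : vector<5xi1>
// maskIndex = ceil((numFrontPadElems + origIndex) / numSrcElemsPerDest)
//           = ceil((2 + 3) / 4) = 2
%compressed_mask = vector.constant_mask [2] : vector<2xi1>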


Full diff: https://github.com/llvm/llvm-project/pull/116520.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+13-28)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+47-6)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc8bab325184b8..26a5f566f34948 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -125,34 +125,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                 return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                              newMaskOperands);
               })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                ArrayRef<int64_t> maskDimSizes =
+                    constantMaskOp.getMaskDimSizes();
+                size_t numMaskOperands = maskDimSizes.size();
+                int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+                int64_t maskIndex = llvm::divideCeil(
+                    numFrontPadElems + origIndex, numSrcElemsPerDest);
+                SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+                newMaskDimSizes.push_back(maskIndex);
+                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                               newMaskDimSizes);
+              })
           .Case<arith::ConstantOp>([&](auto constantOp)
                                        -> std::optional<Operation *> {
             // TODO: Support multiple dimensions.
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b1a0d4f924f3cf..73ce7ac9be2437 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -57,7 +57,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 // CHECK-LABEL: func @vector_cst_maskedload_i2(
 // CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -74,6 +74,48 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 // -----
 
+func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<4x3x5xi2>
+  %cst = arith.constant dense<0> : vector<3x5xi2>
+  %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
+  %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
+    memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
+// CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
+
+// compressed mask, used for emulated masked load
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
+// CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
+
+// Create a padded and shifted passthru vector
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [2], strides = [1]} 
+
+// CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
+// CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
+
+// pad and shift the original mask to match the size and location of the loaded value.
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [2], strides = [1]} 
+// CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
+
+// -----
+
 func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -203,7 +245,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
 // CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
 // CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
-// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
 
 // Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -268,18 +310,17 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
 // CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
 
 // CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
+
+// Emulated masked load from alloc:
 // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
 // CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-
-// Emulated masked load from alloc:
 // CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
 // CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
 
-// Select from emulated loaded vector and passthru vector:
-// TODO: fold this part if possible.
+// Select from emulated loaded vector and passthru vector: (TODO: fold this part if possible)
 // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
 // CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>

@banach-space
Contributor

In the case of unaligned indexing and vector.constant_mask, notice that numFrontPadElems is always strictly smaller than numSrcElemsPerDest, which means that with a non-zero numFrontPadElems, the compressed constant_mask op will not have any leading zero elements in the innermost dimension, although the relevant values and sizes might change due to the extra step of shifting and aligning elements.

There's quite a bit of detail here, but the goal and justification are unclear to me. Is the intent something along these lines?

Ensure that vector.constant_mask is always used when creating a constant mask (instead of requiring arith.constant).
This simplifies the logic and facilitates support for cases where the original mask is rank N, with N > 1.
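
For illustration, the test updates in this diff show exactly this change: the compressed mask that used to be materialized as an arith.constant is now emitted as a vector.constant_mask.

// Before this patch:
%newmask = arith.constant dense<true> : vector<2xi1>
// After this patch:
%newmask = vector.constant_mask [2] : vector<2xi1>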

This patch enables multi-dimensional support by simply observing the abovementioned property and eliminating the constraints.

Could you clarify this statement? Phrases like "multi-dimensional support" are quite broad. From what I understand, this patch specifically adds support for extracting 1-D masks from N-D masks (where N > 1). True "multi-dimensional support" would involve directly handling full N-D masks (e.g., 2-D masks) rather than just reducing their rank.

@banach-space banach-space left a comment

Thanks! Some minor suggestions inline.

@hanhanW hanhanW left a comment

Thanks for the patch! IIUC, it works because we scope the emulation of vector.maskedload to 1-D result vectors. There are two "masks" in the emulation. One is used by the vector.maskedload, which ensures that we load enough bits from memory. The other is used for the final result, selecting the bits from the loaded data and the passthru. In this patch, we observe that we can simplify the logic for the first mask, which looks good to me. Please correct me if I misread something. :)
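
A rough sketch of the two masks, abridged from the CHECK lines of the new multi-dim test in this patch (SSA names are illustrative):

// Mask 1: compressed constant mask, consumed by the emulated i8 maskedload so that enough bytes are read.
%new_mask = vector.constant_mask [2, 1] : vector<3x2xi1>
%load_mask = vector.extract %new_mask[1] : vector<2xi1> from vector<3x2xi1>
%loaded = vector.maskedload %alloc[%c7], %load_mask, %passthru_i8 : memref<15xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
%loaded_i2 = vector.bitcast %loaded : vector<2xi8> to vector<8xi2>
// Mask 2: the original mask, padded and shifted, selects between the loaded bits and the passthru.
%padded_mask = vector.insert_strided_slice %orig_mask, %false_vec {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
%result = arith.select %padded_mask, %loaded_i2, %padded_passthru : vector<8xi1>, vector<8xi2>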

Could you clarify this statement? Phrases like "multi-dimensional support" are quite broad. From what I understand, this patch specifically adds support for extracting 1-D masks from N-D masks (where N > 1). True "multi-dimensional support" would involve directly handling full N-D masks (e.g., 2-D masks) rather than just reducing their rank.

+1 on what @banach-space said; I had the same confusion. I checked the code and found that the maskedload emulation is scoped to 1-D cases:

// See #115653
if (op.getVectorType().getRank() != 1)
  return rewriter.notifyMatchFailure(op,
                                     "only 1-D vectors are supported ATM");

@lialan lialan force-pushed the lialan/constant_mask_improvement branch from 5862798 to f9cc356 Compare November 21, 2024 02:05
@lialan lialan changed the title [MLIR] vector.constant_mask to support unaligned cases [MLIR] Refactor mask compression logic when emulating vector.maskedload ops Nov 21, 2024
@lialan lialan force-pushed the lialan/constant_mask_improvement branch from f9cc356 to 6f09264 Compare November 21, 2024 02:41

lialan commented Nov 21, 2024

@banach-space @hanhanW Thanks for the review! I have updated the PR description to represent the change more faithfully, along with other changes based on the comments!

@hanhanW hanhanW left a comment

I like the refactoring idea, and it looks good to me. I only have a question about removing the check.

@banach-space banach-space left a comment

A few nits that you can ignore. Agreed with @hanhanW though that the check should be preserved.

@lialan lialan force-pushed the lialan/constant_mask_improvement branch from 6f09264 to 7dace4e Compare November 26, 2024 02:16
@hanhanW hanhanW left a comment

thanks, lgtm!

@dcaballe dcaballe left a comment

Cool!

@banach-space banach-space left a comment

LGTM, thank you!

@hanhanW hanhanW merged commit 1669ac4 into llvm:main Nov 27, 2024
8 checks passed