Skip to content

Conversation

@Groverkss
Copy link
Member

@Groverkss Groverkss commented Jul 2, 2025

Adds a folder to vector.constant_mask to fold to SplatElementsAttr when possible

@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

Adds a folder to vector.constant_mask to fold to SplatElementsAttr when possible


Full diff: https://github.com/llvm/llvm-project/pull/146724.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+20)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+32-8)
  • (modified) mlir/test/Dialect/Vector/vector-mem-transforms.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
 
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_CreateMaskOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..a462b3701ddbb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
   return true;
 }
 
+static Attribute createBoolSplat(ShapedType ty, bool x) {
+  return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
+}
+
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+  ArrayRef<int64_t> bounds = getMaskDimSizes();
+  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+  // Check the corner case of 0-D vectors first.
+  if (vectorSizes.size() == 0) {
+    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+    return createBoolSplat(getVectorType(), bounds[0] == 1);
+  }
+  // Fold vector.constant_mask to splat if possible.
+  if (bounds == vectorSizes)
+    return createBoolSplat(getVectorType(), true);
+  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+    return createBoolSplat(getVectorType(), false);
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // CreateMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
 // CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG:   %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
 // CHECK:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
 // CHECK:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
 // CHECK:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
 // CHECK:       %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
 // CHECK:       scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK:         %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
 // CHECK:         %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
 // CHECK:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
 // CHECK:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
 func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
   %c-1 = arith.constant -1 : index
-  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  // CHECK: arith.constant dense<false> : vector<[8]xi1>
   %0 = vector.create_mask %c-1 : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
 func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
   %cneg2 = arith.constant -2 : index
   %c5 = arith.constant 5 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
 func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
   %c16 = arith.constant 16 : index
   %0 = vector.vscale
   %1 = arith.muli %0, %c16 : index
-  // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+  // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
   return %10 : vector<8x[16]xi1>
 }
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<false>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [1] : vector<i1>
+  return %0 : vector<i1>
+}
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+  // CHECK: arith.constant dense<true> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+  // CHECK: arith.constant dense<false> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+  // CHECK: arith.constant dense<false> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+  // CHECK: arith.constant dense<true> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
 // CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT:      return
 func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {

@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

Adds a folder to vector.constant_mask to fold to SplatElementsAttr when possible


Full diff: https://github.com/llvm/llvm-project/pull/146724.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+20)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+32-8)
  • (modified) mlir/test/Dialect/Vector/vector-mem-transforms.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
 
   let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_CreateMaskOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..a462b3701ddbb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
   return true;
 }
 
+static Attribute createBoolSplat(ShapedType ty, bool x) {
+  return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
+}
+
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+  ArrayRef<int64_t> bounds = getMaskDimSizes();
+  ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+  // Check the corner case of 0-D vectors first.
+  if (vectorSizes.size() == 0) {
+    assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+    return createBoolSplat(getVectorType(), bounds[0] == 1);
+  }
+  // Fold vector.constant_mask to splat if possible.
+  if (bounds == vectorSizes)
+    return createBoolSplat(getVectorType(), true);
+  if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+    return createBoolSplat(getVectorType(), false);
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // CreateMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
 // CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG:   %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
 // CHECK:       %[[a:.*]] = arith.extui %[[p]] : i32 to i64
 // CHECK:       %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
 // CHECK:       %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
 // CHECK:       %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
 // CHECK:       scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK:         %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
 // CHECK:         %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
 // CHECK:         %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
 // CHECK:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
 func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
   %c-1 = arith.constant -1 : index
-  // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+  // CHECK: arith.constant dense<false> : vector<[8]xi1>
   %0 = vector.create_mask %c-1 : vector<[8]xi1>
   return %0 : vector<[8]xi1>
 }
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
 func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
   %cneg2 = arith.constant -2 : index
   %c5 = arith.constant 5 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
 func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
-  // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+  // CHECK: arith.constant dense<false> : vector<4x3xi1>
   %0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
   %c16 = arith.constant 16 : index
   %0 = vector.vscale
   %1 = arith.muli %0, %c16 : index
-  // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+  // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
   return %10 : vector<8x[16]xi1>
 }
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+  // CHECK: arith.constant dense<false>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+  return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+  // CHECK: arith.constant dense<true>
+  // CHECK-NOT: vector.constant_mask
+  %0 = vector.constant_mask [1] : vector<i1>
+  return %0 : vector<i1>
+}
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+  // CHECK: arith.constant dense<true> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x2xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+  // CHECK: arith.constant dense<false> : vector<2x2xi1>
   return %1 : vector<2x2xi1>
 }
 
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+  // CHECK: arith.constant dense<false> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
   %1 = vector.extract_strided_slice %0
     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
       : vector<4x3xi1> to vector<2x1xi1>
-  // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+  // CHECK: arith.constant dense<true> : vector<2x1xi1>
   return %1 : vector<2x1xi1>
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
 // CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK-NEXT:      vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT:      return
 func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {

Copy link
Contributor

@fabianmcg fabianmcg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, LGTM % suggestions

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Groverkss Groverkss changed the title [mlir][Vector] Fold vector.constant_mask to splat [mlir][Vector] Fold vector.constant_mask to SplatElementsAttr Jul 2, 2025
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks! Minor comments

%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
// CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Errr... this original output doesn't look correct to me... @banach-space? Do we have a bug somewhere else?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sizes that correspond to scalable dimensions are implicitly multiplied by vscale, though currently only zero (none set) or the size of the dim/vscale (all set) are supported

Seems correct

@Groverkss Groverkss merged commit 5eb195f into llvm:main Jul 4, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants