[mlir][Vector] Fold vector.constant_mask to SplatElementsAttr #146724
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes: Adds a folder to vector.constant_mask to fold to SplatElementsAttr when possible.

Full diff: https://github.com/llvm/llvm-project/pull/146724.diff

5 Files Affected:
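As a minimal sketch of the fold's effect (mirroring the new test cases in the diff below), an all-true or all-false vector.constant_mask is rewritten into an equivalent arith.constant splat:

// Before folding:
%all  = vector.constant_mask [2, 4] : vector<2x4xi1>
%none = vector.constant_mask [0, 0] : vector<2x4xi1>

// After folding (e.g. during canonicalization):
%all  = arith.constant dense<true>  : vector<2x4xi1>
%none = arith.constant dense<false> : vector<2x4xi1>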
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dfb2756e57bea..ec2c87ca1cf44 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
let hasVerifier = 1;
+ let hasFolder = 1;
}
def Vector_CreateMaskOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..a462b3701ddbb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6594,6 +6594,26 @@ bool ConstantMaskOp::isAllOnesMask() {
return true;
}
+static Attribute createBoolSplat(ShapedType ty, bool x) {
+ return SplatElementsAttr::get(ty, BoolAttr::get(ty.getContext(), x));
+}
+
+OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
+ ArrayRef<int64_t> bounds = getMaskDimSizes();
+ ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
+ // Check the corner case of 0-D vectors first.
+ if (vectorSizes.size() == 0) {
+ assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
+ return createBoolSplat(getVectorType(), bounds[0] == 1);
+ }
+ // Fold vector.constant_mask to splat if possible.
+ if (bounds == vectorSizes)
+ return createBoolSplat(getVectorType(), true);
+ if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
+ return createBoolSplat(getVectorType(), false);
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// CreateMaskOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
index 99d6a3dc390e0..35fd7c33e4cfe 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir
@@ -23,6 +23,7 @@
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
// CHECK: %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
-// CHECK: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
// CHECK: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
// CHECK: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..8cda8d47cb908 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
%c-1 = arith.constant -1 : index
- // CHECK: vector.constant_mask [0] : vector<[8]xi1>
+ // CHECK: arith.constant dense<false> : vector<[8]xi1>
%0 = vector.create_mask %c-1 : vector<[8]xi1>
return %0 : vector<[8]xi1>
}
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
%cneg2 = arith.constant -2 : index
%c5 = arith.constant 5 : index
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
%0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
%0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
return %0 : vector<4x3xi1>
}
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
- // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+ // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
return %10 : vector<8x[16]xi1>
}
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
// -----
+// CHECK-LABEL: constant_mask_to_true_splat
+func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
+ // CHECK: arith.constant dense<true>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [2, 4] : vector<2x4xi1>
+ return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_false_splat
+func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
+ // CHECK: arith.constant dense<false>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [0, 0] : vector<2x4xi1>
+ return %0 : vector<2x4xi1>
+}
+
+// CHECK-LABEL: constant_mask_to_true_splat_0d
+func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
+ // CHECK: arith.constant dense<true>
+ // CHECK-NOT: vector.constant_mask
+ %0 = vector.constant_mask [1] : vector<i1>
+ return %0 : vector<i1>
+}
+
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
- // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
+ // CHECK: arith.constant dense<true> : vector<2x2xi1>
return %1 : vector<2x2xi1>
}
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
- // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
+ // CHECK: arith.constant dense<false> : vector<2x2xi1>
return %1 : vector<2x2xi1>
}
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
: vector<4x3xi1> to vector<2x1xi1>
- // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
+ // CHECK: arith.constant dense<false> : vector<2x1xi1>
return %1 : vector<2x1xi1>
}
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
%1 = vector.extract_strided_slice %0
{offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
: vector<4x3xi1> to vector<2x1xi1>
- // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
+ // CHECK: arith.constant dense<true> : vector<2x1xi1>
return %1 : vector<2x1xi1>
}
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..e6593320f1bde 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
fabianmcg left a comment:
Nice, LGTM % suggestions
ftynse left a comment:
Please don't use splat, it is on track to be removed: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
dcaballe left a comment:
LGTM, thanks! Minor comments
On these lines of mlir/test/Dialect/Vector/canonicalize.mlir:

%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
// CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
Errr... this original output doesn't look correct to me... @banach-space? Do we have a bug somewhere else?
Sizes that correspond to scalable dimensions are implicitly multiplied by vscale, though currently only zero (none set) or the size of the dim/vscale (all set) are supported.
Seems correct.
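A minimal sketch of that semantics, based on the scalable-mask tests in this PR (the all-true case is taken directly from the test above; the all-false form is inferred from the folder's all-zeros check):

// All lanes set: the size 16 on the scalable dim denotes 16 * vscale lanes,
// i.e. the whole dimension, so this folds to an all-true splat.
%all = vector.constant_mask [8, 16] : vector<8x[16]xi1>
// folds to: arith.constant dense<true> : vector<8x[16]xi1>

// No lanes set: folds to an all-false splat.
%none = vector.constant_mask [0, 0] : vector<8x[16]xi1>
// folds to: arith.constant dense<false> : vector<8x[16]xi1>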
Adds a folder to vector.constant_mask to fold to SplatElementsAttr when possible