From 7deabd475b2feb715a70a8866fc2fc546a609caf Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Sat, 21 Dec 2024 19:06:58 -0800 Subject: [PATCH 1/2] [mlir][Vector] Support poison in `vector.shuffle` mask This PR extends the existing poison support in https://mlir.llvm.org/docs/Dialects/UBOps/ by representing poison mask values in `vector.shuffle`. Similar to LLVM (see https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/Instructions.h#L1884) this requires defining an integer value (`-1`) representing poison in the `vector.shuffle` mask. The current implementation parses and prints `-1` for the poison value. I implemented a custom parser/printer to use the `poison` keyword instead but I think it's an overkill to have to introduce a hand-written parsers/printers for every operation supporting poison. I also explored adding new flavors of `DenseIXArrayAttr` that could take an argument to represent the poison value, but I also desisted as the resulting code was too complex. Happy to get feedback about this and improve the assembly format as a follow-up. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 10 ++++++++-- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 10 ++++++++++ .../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 11 +++++++++++ mlir/test/Dialect/Vector/ops.mlir | 7 +++++++ 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 30a5b06374fad..a786e4696415c 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -434,7 +434,7 @@ def Vector_ShuffleOp The shuffle operation constructs a permutation (or duplication) of elements from two input vectors, returning a vector with the same element type as the input and a length that is the same as the shuffle mask. The two input - vectors must have the same element type, same rank , and trailing dimension + vectors must have the same element type, same rank, and trailing dimension sizes and shuffles their values in the leading dimension (which may differ in size) according to the given mask. The legality rules are: @@ -448,7 +448,8 @@ def Vector_ShuffleOp * the mask length equals the leading dimension size of the result * numbering the input vector indices left to right across the operands, all mask values must be within range, viz. given two k-D operands v1 and v2 - above, all mask values are in the range [0,s_1+t_1) + above, all mask values are in the range [0,s_1+t_1). -1 is used to + represent a poison mask value. Note, scalable vectors are not supported. @@ -463,10 +464,15 @@ def Vector_ShuffleOp : vector<2xf32>, vector<2xf32> ; yields vector<4xf32> %3 = vector.shuffle %a, %b[0, 1] : vector, vector ; yields vector<2xf32> + %4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1] + : vector<4xf32>, vector<4xf32> ; yields vector<6xf32> ``` }]; let extraClassDeclaration = [{ + // Integer to represent a poison value in a vector shuffle mask. + static constexpr int64_t kMaskPoisonValue = -1; + VectorType getV1VectorType() { return ::llvm::cast(getV1().getType()); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ae1cf95732336..696d1e0f9b1e6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() { int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); for (auto [idx, maskPos] : llvm::enumerate(mask)) { - if (maskPos < 0 || maskPos >= indexSize) + if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize)) return emitOpError("mask index #") << (idx + 1) << " out of range"; } return success(); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index f95e943250bd4..931cc36c9d4a8 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex // ----- +func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> { + %1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32> + return %1 : vector<4xf32> +} +// CHECK-LABEL: @shuffle_poison_mask( +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>) +// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32> + +// ----- + func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> { %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32> return %1 : vector<5xf32> diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 103148633bf97..fd73cea5e4f30 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> { // ----- +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32> +// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32> +// CHECK: return %[[SHUFFLE]] : vector<4xi32> +func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> { + %shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32> + return %shuffle : vector<4xi32> +} + +// ----- + // CHECK-LABEL: func @interleave // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>) // CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 961f1b5ffeabe..cd6f3f518a1c0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32 return %1 : vector<3x4xf32> } +// CHECK-LABEL: @shuffle_poison_mask +func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> { + // CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32> + %1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32> + return %1 : vector<4xf32> +} + // CHECK-LABEL: @extract_element_0d func.func @extract_element_0d(%a: vector) -> f32 { // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector From dfd0c80ca723df9e3147454d71337e354223ff6f Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 17 Jan 2025 14:03:23 -0800 Subject: [PATCH 2/2] Feedback --- .../mlir/Dialect/Vector/IR/VectorOps.td | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index a786e4696415c..4331eda166196 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -435,9 +435,8 @@ def Vector_ShuffleOp from two input vectors, returning a vector with the same element type as the input and a length that is the same as the shuffle mask. The two input vectors must have the same element type, same rank, and trailing dimension - sizes and shuffles their values in the - leading dimension (which may differ in size) according to the given mask. - The legality rules are: + sizes and shuffles their values in the leading dimension (which may differ + in size) according to the given mask. The legality rules are: * the two operands must have the same element type as the result - Either, the two operands and the result must have the same rank and trailing dimension sizes, viz. given two k-D operands @@ -448,8 +447,9 @@ def Vector_ShuffleOp * the mask length equals the leading dimension size of the result * numbering the input vector indices left to right across the operands, all mask values must be within range, viz. given two k-D operands v1 and v2 - above, all mask values are in the range [0,s_1+t_1). -1 is used to - represent a poison mask value. + above, all mask values are in the range [0,s_1+t_1). The value `-1` + represents a poison mask value, which specifies that the selected element + is poison. Note, scalable vectors are not supported. @@ -706,6 +706,8 @@ def Vector_ExtractOp : %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> ``` + + TODO: Implement support for poison indices. }]; let arguments = (ins @@ -896,6 +898,8 @@ def Vector_InsertOp : %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> ``` + + TODO: Implement support for poison indices. }]; let arguments = (ins @@ -986,6 +990,8 @@ def Vector_ScalableInsertOp : ```mlir %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32> ``` + + TODO: Implement support for poison indices. }]; let assemblyFormat = [{ @@ -1037,6 +1043,8 @@ def Vector_ScalableExtractOp : ```mlir %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32> ``` + + TODO: Implement support for poison indices. }]; let assemblyFormat = [{ @@ -1081,6 +1089,8 @@ def Vector_InsertStridedSliceOp : {offsets = [0, 0, 2], strides = [1, 1]}: vector<2x4xf32> into vector<16x4x8xf32> ``` + + TODO: Implement support for poison indices. }]; let assemblyFormat = [{ @@ -1226,6 +1236,8 @@ def Vector_ExtractStridedSliceOp : %1 = vector.extract_strided_slice %0[0:2:1][2:4:1] vector<4x8x16xf32> to vector<2x4x16xf32> ``` + + TODO: Implement support for poison indices. }]; let builders = [ OpBuilder<(ins "Value":$source, "ArrayRef":$offsets,