diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 30a5b06374fad..4331eda166196 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -434,10 +434,9 @@ def Vector_ShuffleOp The shuffle operation constructs a permutation (or duplication) of elements from two input vectors, returning a vector with the same element type as the input and a length that is the same as the shuffle mask. The two input - vectors must have the same element type, same rank , and trailing dimension - sizes and shuffles their values in the - leading dimension (which may differ in size) according to the given mask. - The legality rules are: + vectors must have the same element type, same rank, and trailing dimension + sizes and shuffles their values in the leading dimension (which may differ + in size) according to the given mask. The legality rules are: * the two operands must have the same element type as the result - Either, the two operands and the result must have the same rank and trailing dimension sizes, viz. given two k-D operands @@ -448,7 +447,9 @@ def Vector_ShuffleOp * the mask length equals the leading dimension size of the result * numbering the input vector indices left to right across the operands, all mask values must be within range, viz. given two k-D operands v1 and v2 - above, all mask values are in the range [0,s_1+t_1) + above, all mask values are in the range [0,s_1+t_1). The value `-1` + represents a poison mask value, which specifies that the selected element + is poison. Note, scalable vectors are not supported. @@ -463,10 +464,15 @@ def Vector_ShuffleOp : vector<2xf32>, vector<2xf32> ; yields vector<4xf32> %3 = vector.shuffle %a, %b[0, 1] : vector, vector ; yields vector<2xf32> + %4 = vector.shuffle %a, %b[0, 4, -1, -1, -1, -1] + : vector<4xf32>, vector<4xf32> ; yields vector<6xf32> ``` }]; let extraClassDeclaration = [{ + // Integer to represent a poison value in a vector shuffle mask. + static constexpr int64_t kMaskPoisonValue = -1; + VectorType getV1VectorType() { return ::llvm::cast(getV1().getType()); } @@ -700,6 +706,8 @@ def Vector_ExtractOp : %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> ``` + + TODO: Implement support for poison indices. }]; let arguments = (ins @@ -890,6 +898,8 @@ def Vector_InsertOp : %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> ``` + + TODO: Implement support for poison indices. }]; let arguments = (ins @@ -980,6 +990,8 @@ def Vector_ScalableInsertOp : ```mlir %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32> ``` + + TODO: Implement support for poison indices. }]; let assemblyFormat = [{ @@ -1031,6 +1043,8 @@ def Vector_ScalableExtractOp : ```mlir %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32> ``` + + TODO: Implement support for poison indices. }]; let assemblyFormat = [{ @@ -1075,6 +1089,8 @@ def Vector_InsertStridedSliceOp : {offsets = [0, 0, 2], strides = [1, 1]}: vector<2x4xf32> into vector<16x4x8xf32> ``` + + TODO: Implement support for poison indices. }]; let assemblyFormat = [{ @@ -1220,6 +1236,8 @@ def Vector_ExtractStridedSliceOp : %1 = vector.extract_strided_slice %0[0:2:1][2:4:1] vector<4x8x16xf32> to vector<2x4x16xf32> ``` + + TODO: Implement support for poison indices. }]; let builders = [ OpBuilder<(ins "Value":$source, "ArrayRef":$offsets, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ae1cf95732336..696d1e0f9b1e6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2600,7 +2600,7 @@ LogicalResult ShuffleOp::verify() { int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); for (auto [idx, maskPos] : llvm::enumerate(mask)) { - if (maskPos < 0 || maskPos >= indexSize) + if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize)) return emitOpError("mask index #") << (idx + 1) << " out of range"; } return success(); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index f95e943250bd4..931cc36c9d4a8 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1105,6 +1105,16 @@ func.func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex // ----- +func.func @shuffle_poison_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<4xf32> { + %1 = vector.shuffle %arg0, %arg1 [0, -1, 3, -1] : vector<2xf32>, vector<2xf32> + return %1 : vector<4xf32> +} +// CHECK-LABEL: @shuffle_poison_mask( +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>) +// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, -1, 3, -1] : vector<2xf32> + +// ----- + func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> { %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32> return %1 : vector<5xf32> diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 103148633bf97..fd73cea5e4f30 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -613,6 +613,17 @@ func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> { // ----- +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32> +// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [1 : i32, -1 : i32, 5 : i32, -1 : i32] %[[ARG0]], %[[ARG1]] : vector<4xi32>, vector<4xi32> -> vector<4xi32> +// CHECK: return %[[SHUFFLE]] : vector<4xi32> +func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<4xi32> { + %shuffle = vector.shuffle %v0, %v1 [1, -1, 5, -1] : vector<4xi32>, vector<4xi32> + return %shuffle : vector<4xi32> +} + +// ----- + // CHECK-LABEL: func @interleave // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>) // CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 961f1b5ffeabe..cd6f3f518a1c0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -190,6 +190,13 @@ func.func @shuffle2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32 return %1 : vector<3x4xf32> } +// CHECK-LABEL: @shuffle_poison_mask +func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4xf32> { + // CHECK: vector.shuffle %{{.*}}, %{{.*}}[1, -1, 6, -1] : vector<4xf32>, vector<4xf32> + %1 = vector.shuffle %a, %a[1, -1, 6, -1] : vector<4xf32>, vector<4xf32> + return %1 : vector<4xf32> +} + // CHECK-LABEL: @extract_element_0d func.func @extract_element_0d(%a: vector) -> f32 { // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector