diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index b547839d76738..4cd6c17e3379c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> { let summary = "Convert Vector dialect to SPIR-V dialect"; let constructor = "mlir::createConvertVectorToSPIRVPass()"; - let dependentDialects = ["spirv::SPIRVDialect"]; + let dependentDialects = [ + "spirv::SPIRVDialect", + "ub::UBDialect" + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td index c439ca083e2e0..1922cc63ef353 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td +++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td @@ -26,6 +26,15 @@ def Vector_Dialect : Dialect { // Base class for Vector dialect ops. class Vector_Op traits = []> : - Op; + Op { + + // Includes definitions for operations that support the use of poison values + // within positive index ranges. + code extraPoisonClassDeclaration = [{ + // Integer to represent a poison index within a static and positive integer + // range. + static constexpr int64_t kPoisonIndex = -1; + }]; +} #endif // MLIR_DIALECT_VECTOR_IR_VECTOR diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 4331eda166196..3b027dcfdfc70 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -469,10 +469,7 @@ def Vector_ShuffleOp ``` }]; - let extraClassDeclaration = [{ - // Integer to represent a poison value in a vector shuffle mask. - static constexpr int64_t kMaskPoisonValue = -1; - + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getV1VectorType() { return ::llvm::cast(getV1().getType()); } @@ -693,9 +690,10 @@ def Vector_ExtractOp : Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at the proper position. Degenerates to an element type if n-k is zero. - Dynamic indices must be greater or equal to zero and less than the size of - the corresponding dimension. The result is undefined if any index is - out-of-bounds. + Static and dynamic indices must be greater or equal to zero and less than + the size of the corresponding dimension. The result is undefined if any + index is out-of-bounds. The value `-1` represents a poison index, which + specifies that the extracted element is poison. Example: @@ -705,9 +703,8 @@ def Vector_ExtractOp : %3 = vector.extract %1[]: vector from vector %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> + %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32> ``` - - TODO: Implement support for poison indices. }]; let arguments = (ins @@ -724,7 +721,7 @@ def Vector_ExtractOp : OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, ]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } @@ -885,9 +882,10 @@ def Vector_InsertOp : and inserts the n-D source into the (n+k)-D destination at the proper position. Degenerates to a scalar or a 0-d vector source type when n = 0. - Dynamic indices must be greater or equal to zero and less than the size of - the corresponding dimension. The result is undefined if any index is - out-of-bounds. + Static and dynamic indices must be greater or equal to zero and less than + the size of the corresponding dimension. The result is undefined if any + index is out-of-bounds. The value `-1` represents a poison index, which + specifies that the resulting vector is poison. Example: @@ -897,9 +895,8 @@ def Vector_InsertOp : %8 = vector.insert %6, %7[] : f32 into vector %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> + %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32> ``` - - TODO: Implement support for poison indices. }]; let arguments = (ins @@ -917,7 +914,7 @@ def Vector_InsertOp : OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, ]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { return ::llvm::cast(getDest().getType()); @@ -990,15 +987,13 @@ def Vector_ScalableInsertOp : ```mlir %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32> ``` - - TODO: Implement support for poison indices. }]; let assemblyFormat = [{ $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { return ::llvm::cast(getSource().getType()); } @@ -1043,15 +1038,13 @@ def Vector_ScalableExtractOp : ```mlir %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32> ``` - - TODO: Implement support for poison indices. }]; let assemblyFormat = [{ $source `[` $pos `]` attr-dict `:` type($res) `from` type($source) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { return ::llvm::cast(getSource().getType()); } @@ -1089,8 +1082,6 @@ def Vector_InsertStridedSliceOp : {offsets = [0, 0, 2], strides = [1, 1]}: vector<2x4xf32> into vector<16x4x8xf32> ``` - - TODO: Implement support for poison indices. }]; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index c4a8e7a81fa48..a39ab77fc8fb3 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> { details. }]; let constructor = "mlir::createCanonicalizerPass()"; + let dependentDialects = ["ub::UBDialect"]; let options = [ Option<"topDownProcessingEnabled", "top-down", "bool", /*default=*/"true", diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp index ab9c048f56106..4481c0a497354 100644 --- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" @@ -27,7 +26,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #define DEBUG_TYPE "convert-to-spirv" diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index d3731db1ce55c..af882cb1ca6e9 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -13,7 +13,6 @@ #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp index 1932de1be603b..cc115b1d36826 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3fbfcb4979b49..b35422f4ca3a9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" @@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { return srcElements[posIdx]; } +// Returns `true` if `index` is either within [0, maxIndex) or equal to +// `poisonValue`. +static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, + int64_t maxIndex) { + return index == poisonValue || (index >= 0 && index < maxIndex); +} + //===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// @@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() { for (auto [idx, pos] : llvm::enumerate(position)) { if (auto attr = dyn_cast(pos)) { int64_t constIdx = cast(attr).getInt(); - if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) { + if (!isValidPositiveIndexOrPoison( + constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) { return emitOpError("expected position attribute #") << (idx + 1) << " to be a non-negative integer smaller than the " - "corresponding vector dimension"; + "corresponding vector dimension or poison (-1)"; } } } @@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { return fromElementsOp.getElements()[flatIndex]; } -OpFoldResult ExtractOp::fold(FoldAdaptor) { +/// Fold an insert or extract operation into an poison value when a poison index +/// is found at any dimension of the static position. +static ub::PoisonAttr +foldPoisonIndexInsertExtractOp(MLIRContext *context, + ArrayRef staticPos, int64_t poisonVal) { + if (!llvm::is_contained(staticPos, poisonVal)) + return ub::PoisonAttr(); + + return ub::PoisonAttr::get(context); +} + +OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v. // Note: Do not fold "vector.extract %v[] : f32 from vector" (type // mismatch). if (getNumIndices() == 0 && getVector().getType() == getResult().getType()) return getVector(); + if (auto res = foldPoisonIndexInsertExtractOp( + getContext(), adaptor.getStaticPosition(), kPoisonIndex)) + return res; if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) @@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, resultType.getNumElements())); return success(); } + +/// Fold an insert or extract operation into an poison value when a poison index +/// is found at any dimension of the static position. +template +LogicalResult +canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) { + if (auto poisonAttr = foldPoisonIndexInsertExtractOp( + op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) { + rewriter.replaceOpWithNewOp(op, op.getType(), poisonAttr); + return success(); + } + + return failure(); +} + } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); + results.add(canonicalizePoisonIndexInsertExtractOp); } static void populateFromInt64AttrArray(ArrayAttr arrayAttr, @@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() { int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); for (auto [idx, maskPos] : llvm::enumerate(mask)) { - if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize)) + if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize)) return emitOpError("mask index #") << (idx + 1) << " out of range"; } return success(); @@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() { for (auto [idx, pos] : llvm::enumerate(position)) { if (auto attr = pos.dyn_cast()) { int64_t constIdx = cast(attr).getInt(); - if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) { + if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex, + destVectorType.getDimSize(idx))) { return emitOpError("expected position attribute #") << (idx + 1) << " to be a non-negative integer smaller than the " @@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(canonicalizePoisonIndexInsertExtractOp); } OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { @@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { // (type mismatch). if (getNumIndices() == 0 && getSourceType() == getType()) return getSource(); + if (auto res = foldPoisonIndexInsertExtractOp( + getContext(), adaptor.getStaticPosition(), kPoisonIndex)) + return res; + return {}; } diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 058039e47313e..3a8088bccf299 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils + MLIRUBDialect ) diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 5f46960507036..7ccd503fb0288 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -13,6 +13,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 29bed9aae5682..62649b83d887d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 { // ----- +func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 { + %0 = vector.extract %arg0[-1]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_poison_idx +// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64 +// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32> + +// ----- + func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 { %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32> return %0 : f32 diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index fd73cea5e4f30..383215c016039 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { // ----- +func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 { + // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}} + %0 = vector.extract %arg0[-1] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @extract_size1_vector // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32> // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] @@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { // ----- +func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { + // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}} + %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32> + return %1: vector<4xf32> +} + +// ----- + // CHECK-LABEL: @insert_index_vector // CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32> func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 0eebb6e8d612d..f9e3b772f9f0a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -132,6 +132,37 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index // ----- +// CHECK-LABEL: @extract_scalar_poison_idx +func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 { + // CHECK-NOT: vector.extract + // CHECK-NEXT: ub.poison : f32 + %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @extract_vector_poison_idx +func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> { + // CHECK-NOT: vector.extract + // CHECK-NEXT: ub.poison : vector<5xf32> + %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32> + return %0 : vector<5xf32> +} + +// ----- + +// CHECK-LABEL: @extract_multiple_poison_idx +func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>) + -> vector<8xf32> { + // CHECK-NOT: vector.extract + // CHECK-NEXT: ub.poison : vector<8xf32> + %0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32> + return %0 : vector<8xf32> +} + +// ----- + // CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false // CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> { @@ -2778,7 +2809,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector } - // ----- // CHECK-LABEL: func @vector_insert_const_regression( @@ -2792,6 +2822,39 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> { // ----- +// CHECK-LABEL: @insert_scalar_poison_idx +func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32) + -> vector<4x5xf32> { + // CHECK-NOT: vector.insert + // CHECK-NEXT: ub.poison : vector<4x5xf32> + %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32> + return %0 : vector<4x5xf32> +} + +// ----- + +// CHECK-LABEL: @insert_vector_poison_idx +func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>) + -> vector<4x5xf32> { + // CHECK-NOT: vector.insert + // CHECK-NEXT: ub.poison : vector<4x5xf32> + %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32> + return %0 : vector<4x5xf32> +} + +// ----- + +// CHECK-LABEL: @insert_multiple_poison_idx +func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>) + -> vector<4x5x8xf32> { + // CHECK-NOT: vector.insert + // CHECK-NEXT: ub.poison : vector<4x5x8xf32> + %0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32> + return %0 : vector<4x5x8xf32> +} + +// ----- + // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract // CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32> // CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 1a70791fae125..57e348c7d5991 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -186,8 +186,8 @@ func.func @extract_0d(%arg0: vector) { // ----- func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} - %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32> + // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}} + %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32> } // ----- @@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}} - %1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32> + %1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index cd6f3f518a1c0..67484e06f456d 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector) -> f32 { return %0 : f32 } +// CHECK-LABEL: @extract_poison_idx +func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { + // CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32> + %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32> + return %0 : f32 +} + // CHECK-LABEL: @insert_element_0d func.func @insert_element_0d(%a: f32, %b: vector) -> vector { // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector @@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32> } +// CHECK-LABEL: @insert_poison_idx +func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) { + // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32> + vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32> + return +} + // CHECK-LABEL: @outerproduct func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>