From 4a9ae1caa05042a15f22c36883e78b96c45d920a Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Sat, 18 Jan 2025 15:21:17 -0800 Subject: [PATCH 1/3] [mlir][Vector] Add support for poison indices to `Extract/IndexOp` Following up on #122188, this PR adds support for poison indices to `ExtractOp` and `InsertOp`. It also includes canonicalization patterns to turn extract/insert ops with poison indices into `ub.poison`. --- mlir/include/mlir/Dialect/Vector/IR/Vector.td | 11 ++++- .../mlir/Dialect/Vector/IR/VectorOps.td | 23 +++------- mlir/include/mlir/Transforms/Passes.td | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 35 +++++++++++++-- mlir/lib/Transforms/CMakeLists.txt | 1 + mlir/lib/Transforms/Canonicalizer.cpp | 1 + mlir/test/Dialect/Vector/canonicalize.mlir | 43 ++++++++++++++++++- mlir/test/Dialect/Vector/invalid.mlir | 4 +- mlir/test/Dialect/Vector/ops.mlir | 14 ++++++ 9 files changed, 108 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td index c439ca083e2e0..1922cc63ef353 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td +++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td @@ -26,6 +26,15 @@ def Vector_Dialect : Dialect { // Base class for Vector dialect ops. class Vector_Op traits = []> : - Op; + Op { + + // Includes definitions for operations that support the use of poison values + // within positive index ranges. + code extraPoisonClassDeclaration = [{ + // Integer to represent a poison index within a static and positive integer + // range. + static constexpr int64_t kPoisonIndex = -1; + }]; +} #endif // MLIR_DIALECT_VECTOR_IR_VECTOR diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 4331eda166196..c57e3dd13233c 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -469,10 +469,7 @@ def Vector_ShuffleOp ``` }]; - let extraClassDeclaration = [{ - // Integer to represent a poison value in a vector shuffle mask. - static constexpr int64_t kMaskPoisonValue = -1; - + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getV1VectorType() { return ::llvm::cast(getV1().getType()); } @@ -706,8 +703,6 @@ def Vector_ExtractOp : %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> ``` - - TODO: Implement support for poison indices. }]; let arguments = (ins @@ -724,7 +719,7 @@ def Vector_ExtractOp : OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, ]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } @@ -898,8 +893,6 @@ def Vector_InsertOp : %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> ``` - - TODO: Implement support for poison indices. }]; let arguments = (ins @@ -917,7 +910,7 @@ def Vector_InsertOp : OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, ]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { return ::llvm::cast(getDest().getType()); @@ -990,15 +983,13 @@ def Vector_ScalableInsertOp : ```mlir %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32> ``` - - TODO: Implement support for poison indices. }]; let assemblyFormat = [{ $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { return ::llvm::cast(getSource().getType()); } @@ -1043,15 +1034,13 @@ def Vector_ScalableExtractOp : ```mlir %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32> ``` - - TODO: Implement support for poison indices. }]; let assemblyFormat = [{ $source `[` $pos `]` attr-dict `:` type($res) `from` type($source) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = extraPoisonClassDeclaration # [{ VectorType getSourceVectorType() { return ::llvm::cast(getSource().getType()); } @@ -1089,8 +1078,6 @@ def Vector_InsertStridedSliceOp : {offsets = [0, 0, 2], strides = [1, 1]}: vector<2x4xf32> into vector<16x4x8xf32> ``` - - TODO: Implement support for poison indices. }]; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index c4a8e7a81fa48..a39ab77fc8fb3 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -28,6 +28,7 @@ def Canonicalizer : Pass<"canonicalize"> { details. }]; let constructor = "mlir::createCanonicalizerPass()"; + let dependentDialects = ["ub::UBDialect"]; let options = [ Option<"topDownProcessingEnabled", "top-down", "bool", /*default=*/"true", diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3fbfcb4979b49..eb732df949545 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" @@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { return srcElements[posIdx]; } +// Returns `true` if `index` is either within [0, maxIndex) or equal to +// `poisonValue`. +static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, + int64_t maxIndex) { + return index == poisonValue || (index >= 0 && index < maxIndex); +} + //===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// @@ -1355,7 +1363,8 @@ LogicalResult vector::ExtractOp::verify() { for (auto [idx, pos] : llvm::enumerate(position)) { if (auto attr = dyn_cast(pos)) { int64_t constIdx = cast(attr).getInt(); - if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) { + if (!isValidPositiveIndexOrPoison( + constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) { return emitOpError("expected position attribute #") << (idx + 1) << " to be a non-negative integer smaller than the " @@ -2249,6 +2258,23 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, resultType.getNumElements())); return success(); } + +/// Fold an insert or extract operation into an poison value when a poison index +/// is found at any dimension of the static position. +template +LogicalResult foldPoisonIndexInsertExtractOp(OpTy op, + PatternRewriter &rewriter) { + auto hasPoisonIndex = [](int64_t index) { + return index == OpTy::kPoisonIndex; + }; + + if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex)) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getResult().getType()); + return success(); +} + } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -2257,6 +2283,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); + results.add(foldPoisonIndexInsertExtractOp); } static void populateFromInt64AttrArray(ArrayAttr arrayAttr, @@ -2600,7 +2627,7 @@ LogicalResult ShuffleOp::verify() { int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); for (auto [idx, maskPos] : llvm::enumerate(mask)) { - if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize)) + if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize)) return emitOpError("mask index #") << (idx + 1) << " out of range"; } return success(); @@ -2882,7 +2909,8 @@ LogicalResult InsertOp::verify() { for (auto [idx, pos] : llvm::enumerate(position)) { if (auto attr = pos.dyn_cast()) { int64_t constIdx = cast(attr).getInt(); - if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) { + if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex, + destVectorType.getDimSize(idx))) { return emitOpError("expected position attribute #") << (idx + 1) << " to be a non-negative integer smaller than the " @@ -3020,6 +3048,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(foldPoisonIndexInsertExtractOp); } OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 058039e47313e..3a8088bccf299 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils + MLIRUBDialect ) diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 5f46960507036..7ccd503fb0288 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -13,6 +13,7 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 0eebb6e8d612d..e6f9630d0449a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -132,6 +132,26 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index // ----- +// CHECK-LABEL: @extract_scalar_poison_idx +func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 { + // CHECK-NOT: vector.extract + // CHECK-NEXT: ub.poison : f32 + %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: @extract_vector_poison_idx +func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> { + // CHECK-NOT: vector.extract + // CHECK-NEXT: ub.poison : vector<5xf32> + %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32> + return %0 : vector<5xf32> +} + +// ----- + // CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false // CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> { @@ -2778,7 +2798,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector } - // ----- // CHECK-LABEL: func @vector_insert_const_regression( @@ -2792,6 +2811,28 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> { // ----- +// CHECK-LABEL: @insert_scalar_poison_idx +func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32) + -> vector<4x5xf32> { + // CHECK-NOT: vector.insert + // CHECK-NEXT: ub.poison : vector<4x5xf32> + %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32> + return %0 : vector<4x5xf32> +} + +// ----- + +// CHECK-LABEL: @insert_vector_poison_idx +func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>) + -> vector<4x5xf32> { + // CHECK-NOT: vector.insert + // CHECK-NEXT: ub.poison : vector<4x5xf32> + %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32> + return %0 : vector<4x5xf32> +} + +// ----- + // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract // CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32> // CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 1a70791fae125..9416f4787eefb 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -187,7 +187,7 @@ func.func @extract_0d(%arg0: vector) { func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} - %1 = vector.extract %arg0[0, 0, -1] : f32 from vector<4x8x16xf32> + %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32> } // ----- @@ -247,7 +247,7 @@ func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { func.func @insert_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding dest vector dimension}} - %1 = vector.insert %a, %b[0, 0, -1] : f32 into vector<4x8x16xf32> + %1 = vector.insert %a, %b[0, 0, -5] : f32 into vector<4x8x16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index cd6f3f518a1c0..67484e06f456d 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -247,6 +247,13 @@ func.func @extract_0d(%a: vector) -> f32 { return %0 : f32 } +// CHECK-LABEL: @extract_poison_idx +func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { + // CHECK-NEXT: vector.extract %{{.*}}[-1, 0] : f32 from vector<4x5xf32> + %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32> + return %0 : f32 +} + // CHECK-LABEL: @insert_element_0d func.func @insert_element_0d(%a: f32, %b: vector) -> vector { // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector @@ -299,6 +306,13 @@ func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32> } +// CHECK-LABEL: @insert_poison_idx +func.func @insert_poison_idx(%a: vector<4x5xf32>, %b: f32) { + // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[-1, 0] : f32 into vector<4x5xf32> + vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32> + return +} + // CHECK-LABEL: @outerproduct func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32> From 0d83b2095c57cf7390b8d88db3977e9a30f0cd48 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Tue, 21 Jan 2025 21:57:11 -0800 Subject: [PATCH 2/3] Feedback --- .../mlir/Dialect/Vector/IR/VectorOps.td | 16 +++++++++----- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++----- .../VectorToLLVM/vector-to-llvm.mlir | 10 +++++++++ .../VectorToSPIRV/vector-to-spirv.mlir | 16 ++++++++++++++ mlir/test/Dialect/Vector/canonicalize.mlir | 22 +++++++++++++++++++ mlir/test/Dialect/Vector/invalid.mlir | 2 +- 6 files changed, 61 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c57e3dd13233c..3b027dcfdfc70 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -690,9 +690,10 @@ def Vector_ExtractOp : Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at the proper position. Degenerates to an element type if n-k is zero. - Dynamic indices must be greater or equal to zero and less than the size of - the corresponding dimension. The result is undefined if any index is - out-of-bounds. + Static and dynamic indices must be greater or equal to zero and less than + the size of the corresponding dimension. The result is undefined if any + index is out-of-bounds. The value `-1` represents a poison index, which + specifies that the extracted element is poison. Example: @@ -702,6 +703,7 @@ def Vector_ExtractOp : %3 = vector.extract %1[]: vector from vector %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> + %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32> ``` }]; @@ -880,9 +882,10 @@ def Vector_InsertOp : and inserts the n-D source into the (n+k)-D destination at the proper position. Degenerates to a scalar or a 0-d vector source type when n = 0. - Dynamic indices must be greater or equal to zero and less than the size of - the corresponding dimension. The result is undefined if any index is - out-of-bounds. + Static and dynamic indices must be greater or equal to zero and less than + the size of the corresponding dimension. The result is undefined if any + index is out-of-bounds. The value `-1` represents a poison index, which + specifies that the resulting vector is poison. Example: @@ -892,6 +895,7 @@ def Vector_InsertOp : %8 = vector.insert %6, %7[] : f32 into vector %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> + %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32> ``` }]; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb732df949545..58ab634c74a4a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1368,7 +1368,7 @@ LogicalResult vector::ExtractOp::verify() { return emitOpError("expected position attribute #") << (idx + 1) << " to be a non-negative integer smaller than the " - "corresponding vector dimension"; + "corresponding vector dimension or poison (-1)"; } } } @@ -2264,11 +2264,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, template LogicalResult foldPoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) { - auto hasPoisonIndex = [](int64_t index) { - return index == OpTy::kPoisonIndex; - }; - - if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex)) + if (!llvm::is_contained(op.getStaticPosition(), OpTy::kPoisonIndex)) return failure(); rewriter.replaceOpWithNewOp(op, op.getResult().getType()); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 29bed9aae5682..62649b83d887d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 { // ----- +func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 { + %0 = vector.extract %arg0[-1]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_poison_idx +// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64 +// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32> + +// ----- + func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 { %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32> return %0 : f32 diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index fd73cea5e4f30..383215c016039 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { // ----- +func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 { + // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}} + %0 = vector.extract %arg0[-1] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @extract_size1_vector // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32> // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] @@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { // ----- +func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { + // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}} + %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32> + return %1: vector<4xf32> +} + +// ----- + // CHECK-LABEL: @insert_index_vector // CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32> func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> { diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index e6f9630d0449a..f9e3b772f9f0a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -152,6 +152,17 @@ func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> { // ----- +// CHECK-LABEL: @extract_multiple_poison_idx +func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>) + -> vector<8xf32> { + // CHECK-NOT: vector.extract + // CHECK-NEXT: ub.poison : vector<8xf32> + %0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32> + return %0 : vector<8xf32> +} + +// ----- + // CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false // CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> { @@ -2833,6 +2844,17 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>) // ----- +// CHECK-LABEL: @insert_multiple_poison_idx +func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>) + -> vector<4x5x8xf32> { + // CHECK-NOT: vector.insert + // CHECK-NEXT: ub.poison : vector<4x5x8xf32> + %0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32> + return %0 : vector<4x5x8xf32> +} + +// ----- + // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract // CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32> // CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 9416f4787eefb..57e348c7d5991 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -186,7 +186,7 @@ func.func @extract_0d(%arg0: vector) { // ----- func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} + // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}} %1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32> } From 3d223051e5e6fa87ac3b6072031db80d9b5237ee Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Mon, 27 Jan 2025 16:19:10 -0800 Subject: [PATCH 3/3] Add to folders --- mlir/include/mlir/Conversion/Passes.td | 5 ++- .../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 2 - .../VectorToSPIRV/VectorToSPIRV.cpp | 1 - .../VectorToSPIRV/VectorToSPIRVPass.cpp | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 ++++++++++++++----- 5 files changed, 34 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index b547839d76738..4cd6c17e3379c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1454,7 +1454,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> { let summary = "Convert Vector dialect to SPIR-V dialect"; let constructor = "mlir::createConvertVectorToSPIRVPass()"; - let dependentDialects = ["spirv::SPIRVDialect"]; + let dependentDialects = [ + "spirv::SPIRVDialect", + "ub::UBDialect" + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp index ab9c048f56106..4481c0a497354 100644 --- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" @@ -27,7 +26,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #define DEBUG_TYPE "convert-to-spirv" diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index d3731db1ce55c..af882cb1ca6e9 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -13,7 +13,6 @@ #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp index 1932de1be603b..cc115b1d36826 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 58ab634c74a4a..b35422f4ca3a9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1986,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { return fromElementsOp.getElements()[flatIndex]; } -OpFoldResult ExtractOp::fold(FoldAdaptor) { +/// Fold an insert or extract operation into an poison value when a poison index +/// is found at any dimension of the static position. +static ub::PoisonAttr +foldPoisonIndexInsertExtractOp(MLIRContext *context, + ArrayRef staticPos, int64_t poisonVal) { + if (!llvm::is_contained(staticPos, poisonVal)) + return ub::PoisonAttr(); + + return ub::PoisonAttr::get(context); +} + +OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v. // Note: Do not fold "vector.extract %v[] : f32 from vector" (type // mismatch). if (getNumIndices() == 0 && getVector().getType() == getResult().getType()) return getVector(); + if (auto res = foldPoisonIndexInsertExtractOp( + getContext(), adaptor.getStaticPosition(), kPoisonIndex)) + return res; if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) @@ -2262,13 +2276,15 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, /// Fold an insert or extract operation into an poison value when a poison index /// is found at any dimension of the static position. template -LogicalResult foldPoisonIndexInsertExtractOp(OpTy op, - PatternRewriter &rewriter) { - if (!llvm::is_contained(op.getStaticPosition(), OpTy::kPoisonIndex)) - return failure(); +LogicalResult +canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) { + if (auto poisonAttr = foldPoisonIndexInsertExtractOp( + op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) { + rewriter.replaceOpWithNewOp(op, op.getType(), poisonAttr); + return success(); + } - rewriter.replaceOpWithNewOp(op, op.getResult().getType()); - return success(); + return failure(); } } // namespace @@ -2279,7 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); - results.add(foldPoisonIndexInsertExtractOp); + results.add(canonicalizePoisonIndexInsertExtractOp); } static void populateFromInt64AttrArray(ArrayAttr arrayAttr, @@ -3044,7 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); - results.add(foldPoisonIndexInsertExtractOp); + results.add(canonicalizePoisonIndexInsertExtractOp); } OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { @@ -3053,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { // (type mismatch). if (getNumIndices() == 0 && getSourceType() == getType()) return getSource(); + if (auto res = foldPoisonIndexInsertExtractOp( + getContext(), adaptor.getStaticPosition(), kPoisonIndex)) + return res; + return {}; }