diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b5504ca84fa42..183893c9fdb46 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents( } LogicalResult tosa::GatherOp::verify() { - return verifySameElementTypes(*this, /* inType = */ getValues().getType(), - /* outType = */ getOutput().getType()); + if (verifySameElementTypes(*this, /* inType = */ getValues().getType(), + /* outType = */ getOutput().getType()) + .failed()) { + return failure(); + } + + const ShapeAdaptor valuesShape(getValues().getType()); + const ShapeAdaptor indicesShape(getIndices().getType()); + const ShapeAdaptor outputShape(getOutput().getType()); + + int64_t N = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + + if (valuesShape.hasRank()) { + N = valuesShape.getDimSize(0); + C = valuesShape.getDimSize(2); + } + if (indicesShape.hasRank()) { + const int64_t indicesN = indicesShape.getDimSize(0); + W = indicesShape.getDimSize(1); + if (N == ShapedType::kDynamic) + N = indicesN; + else if (indicesN != ShapedType::kDynamic && N != indicesN) + return emitOpError() << "requires indices dimension 0 to have size " << N + << ", got " << indicesN; + } + if (outputShape.hasRank()) { + const int64_t outputN = outputShape.getDimSize(0); + const int64_t outputW = outputShape.getDimSize(1); + const int64_t outputC = outputShape.getDimSize(2); + if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic && + N != outputN) + return emitOpError() << "requires output dimension 0 to have size " << N + << ", got " << outputN; + + if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic && + W != outputW) + return emitOpError() << "requires output dimension 1 to have size " << W + << ", got " << outputW; + if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic && + C != outputC) + return emitOpError() << "requires output dimension 2 to have size " << C + << ", got " << outputC; + } + return success(); } LogicalResult tosa::ResizeOp::inferReturnTypeComponents( diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index e88fc11d2be88..b23dcd0c9cd3d 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -370,3 +370,36 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor) -> tensor %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor return %0 : tensor } + +// ----- + +// CHECK-LABEL: @test_gather_invalid_indices_N +func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> { + // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}} + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32> + return %0 : tensor<13x26x3xf32> +} + +// ----- +// CHECK-LABEL: test_gather_invalid_out_N +func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> { + // expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}} + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32> + return %0 : tensor<12x26x3xf32> +} + +// ----- +// CHECK-LABEL: test_gather_invalid_out_W +func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> { + // expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}} + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32> + return %0 : tensor<13x28x3xf32> +} + +// ----- +// CHECK-LABEL: test_gather_invalid_out_C +func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> { + // expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}} + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32> + return %0 : tensor<13x26x8xf32> +}