diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 800968e6f4766..bd5c5e56398c1 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { } LogicalResult tosa::ArgMaxOp::verify() { + const ShapedType resultType = llvm::cast(getType()); + // Ensure output is of 32-bit integer - const auto resultETy = llvm::cast(getType()).getElementType(); - if (!resultETy.isIntOrIndex()) + if (const auto resultETy = resultType.getElementType(); + !resultETy.isIntOrIndex()) return emitOpError("result tensor is not of integer type"); - // Ensure axis is within the tensor rank const auto inputType = llvm::cast(getInput().getType()); + if (!inputType.hasRank()) + return success(); + + // Ensure axis is within the tensor rank const int64_t axis = getAxisAttr().getInt(); - if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank())) + if (((axis < 0) || axis >= inputType.getRank())) return emitOpError("specified axis is outside the rank of the tensor"); + if (!resultType.hasRank()) + return success(); + + const ArrayRef inputShape = inputType.getShape(); + const ArrayRef outputShape = resultType.getShape(); + llvm::SmallVector expectedOutputShape(inputShape.begin(), + inputShape.end()); + expectedOutputShape.erase(expectedOutputShape.begin() + axis); + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) + return emitOpError("expected output shape '") + << expectedOutputShape << "', got '" << outputShape << "'"; + return success(); } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index a0184e2d82704..09aba79647c79 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1,10 +1,10 @@ // RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s // CHECK-LABEL: @argmax_nofold -func.func @argmax_nofold(%arg0: tensor) -> tensor { +func.func @argmax_nofold(%arg0: tensor) -> tensor<1xi32> { // CHECK: tosa.argmax - %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor) -> tensor - return %0 : tensor + %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor) -> tensor<1xi32> + return %0 : tensor<1xi32> } // ----- diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir index 8c3ad828ab06f..e06efbbfa1ad9 100644 --- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir +++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir @@ -5,7 +5,7 @@ // ----- // Uses argmax as canonical example to validate constrained TOSA tensor shapes. // CHECK-LABEL: argmax -func.func @test_argmax(%arg0: tensor) -> tensor { - %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor) -> tensor - return %0 : tensor +func.func @test_argmax(%arg0: tensor) -> tensor { + %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor) -> tensor + return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e665510ff0143..76093b0b3c1ca 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16> } + +// ----- + +func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { + // expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}} + %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +}