Skip to content

Commit 4ea8638

Browse files
committed
[mlir][tosa] Add expected output shape check to argmax verifier
Fixes some test cases which incorrectly declared the output shape and added a negative test case. Signed-off-by: Luke Hutton <[email protected]> Change-Id: I7b757d944ec0b2f168fd4ca4ea395249c78c3341
1 parent 0953706 commit 4ea8638

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
438438
}
439439

440440
LogicalResult tosa::ArgMaxOp::verify() {
441+
const ShapedType resultType = llvm::cast<ShapedType>(getType());
442+
441443
// Ensure output is of 32-bit integer
442-
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
443-
if (!resultETy.isIntOrIndex())
444+
if (const auto resultETy = resultType.getElementType();
445+
!resultETy.isIntOrIndex())
444446
return emitOpError("result tensor is not of integer type");
445447

446-
// Ensure axis is within the tensor rank
447448
const auto inputType = llvm::cast<ShapedType>(getInput().getType());
449+
if (!inputType.hasRank())
450+
return success();
451+
452+
// Ensure axis is within the tensor rank
448453
const int64_t axis = getAxisAttr().getInt();
449-
if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
454+
if (((axis < 0) || axis >= inputType.getRank()))
450455
return emitOpError("specified axis is outside the rank of the tensor");
451456

457+
if (!resultType.hasRank())
458+
return success();
459+
460+
const ArrayRef<int64_t> inputShape = inputType.getShape();
461+
const ArrayRef<int64_t> outputShape = resultType.getShape();
462+
llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
463+
inputShape.end());
464+
expectedOutputShape.erase(expectedOutputShape.begin() + axis);
465+
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
466+
return emitOpError("expected output shape '")
467+
<< expectedOutputShape << "', got '" << outputShape << "'";
468+
452469
return success();
453470
}
454471

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
22

33
// CHECK-LABEL: @argmax_nofold
4-
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
4+
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
55
// CHECK: tosa.argmax
6-
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
7-
return %0 : tensor<?x1xi32>
6+
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
7+
return %0 : tensor<1xi32>
88
}
99

1010
// -----

mlir/test/Dialect/Tosa/constrained_shapes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// -----
66
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
77
// CHECK-LABEL: argmax
8-
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
9-
%0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
10-
return %0 : tensor<?xi32>
8+
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
9+
%0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
10+
return %0 : tensor<i32>
1111
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
13921392
%0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
13931393
return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
13941394
}
1395+
1396+
// -----
1397+
1398+
func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
1399+
// expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
1400+
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
1401+
return %0 : tensor<1x2x3xi32>
1402+
}

0 commit comments

Comments
 (0)