-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][tosa] Add expected output shape check to argmax verifier #129870
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Add expected output shape check to argmax verifier #129870
Conversation
Fixes some test cases which incorrectly declared the output shape and added a negative test case. Signed-off-by: Luke Hutton <[email protected]> Change-Id: I7b757d944ec0b2f168fd4ca4ea395249c78c3341
|
@llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesFixes some test cases which incorrectly declared the output shape and added a negative test case. Full diff: https://github.com/llvm/llvm-project/pull/129870.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 800968e6f4766..bd5c5e56398c1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
}
LogicalResult tosa::ArgMaxOp::verify() {
+ const ShapedType resultType = llvm::cast<ShapedType>(getType());
+
// Ensure output is of 32-bit integer
- const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
- if (!resultETy.isIntOrIndex())
+ if (const auto resultETy = resultType.getElementType();
+ !resultETy.isIntOrIndex())
return emitOpError("result tensor is not of integer type");
- // Ensure axis is within the tensor rank
const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+ if (!inputType.hasRank())
+ return success();
+
+ // Ensure axis is within the tensor rank
const int64_t axis = getAxisAttr().getInt();
- if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+ if (((axis < 0) || axis >= inputType.getRank()))
return emitOpError("specified axis is outside the rank of the tensor");
+ if (!resultType.hasRank())
+ return success();
+
+ const ArrayRef<int64_t> inputShape = inputType.getShape();
+ const ArrayRef<int64_t> outputShape = resultType.getShape();
+ llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
+ inputShape.end());
+ expectedOutputShape.erase(expectedOutputShape.begin() + axis);
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
+ return emitOpError("expected output shape '")
+ << expectedOutputShape << "', got '" << outputShape << "'";
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a0184e2d82704..09aba79647c79 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
// CHECK: tosa.argmax
- %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
- return %0 : tensor<?x1xi32>
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
+ return %0 : tensor<1xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 8c3ad828ab06f..e06efbbfa1ad9 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -5,7 +5,7 @@
// -----
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
// CHECK-LABEL: argmax
-func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
- %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
- return %0 : tensor<?xi32>
+func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
+ %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
+ return %0 : tensor<i32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e665510ff0143..76093b0b3c1ca 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
%0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
}
+
+// -----
+
+func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+ // expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
+ return %0 : tensor<1x2x3xi32>
+}
|
|
@llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesFixes some test cases which incorrectly declared the output shape and added a negative test case. Full diff: https://github.com/llvm/llvm-project/pull/129870.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 800968e6f4766..bd5c5e56398c1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
}
LogicalResult tosa::ArgMaxOp::verify() {
+ const ShapedType resultType = llvm::cast<ShapedType>(getType());
+
// Ensure output is of 32-bit integer
- const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
- if (!resultETy.isIntOrIndex())
+ if (const auto resultETy = resultType.getElementType();
+ !resultETy.isIntOrIndex())
return emitOpError("result tensor is not of integer type");
- // Ensure axis is within the tensor rank
const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+ if (!inputType.hasRank())
+ return success();
+
+ // Ensure axis is within the tensor rank
const int64_t axis = getAxisAttr().getInt();
- if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+ if (((axis < 0) || axis >= inputType.getRank()))
return emitOpError("specified axis is outside the rank of the tensor");
+ if (!resultType.hasRank())
+ return success();
+
+ const ArrayRef<int64_t> inputShape = inputType.getShape();
+ const ArrayRef<int64_t> outputShape = resultType.getShape();
+ llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
+ inputShape.end());
+ expectedOutputShape.erase(expectedOutputShape.begin() + axis);
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
+ return emitOpError("expected output shape '")
+ << expectedOutputShape << "', got '" << outputShape << "'";
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a0184e2d82704..09aba79647c79 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
// CHECK: tosa.argmax
- %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
- return %0 : tensor<?x1xi32>
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
+ return %0 : tensor<1xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 8c3ad828ab06f..e06efbbfa1ad9 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -5,7 +5,7 @@
// -----
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
// CHECK-LABEL: argmax
-func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
- %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
- return %0 : tensor<?xi32>
+func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
+ %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
+ return %0 : tensor<i32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e665510ff0143..76093b0b3c1ca 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
%0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
}
+
+// -----
+
+func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+ // expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
+ return %0 : tensor<1x2x3xi32>
+}
|
FranklandJack
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Fixes some test cases which incorrectly declared the output shape and added a negative test case.