Conversation

@lhutton1 (Contributor) commented Mar 5, 2025

Fixes some test cases that incorrectly declared the output shape and adds a negative test case.

Fixes some test cases that incorrectly declared the output shape
and adds a negative test case.

Signed-off-by: Luke Hutton <[email protected]>
Change-Id: I7b757d944ec0b2f168fd4ca4ea395249c78c3341
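
For context, the verifier change in this PR requires the tosa.argmax result shape to equal the input shape with the reduced axis removed (unranked operands or results skip the check, and dynamic dimensions are accepted as compatible). A minimal valid example of that rule, reusing the shapes from the negative test below; the function name is illustrative:

func.func @argmax_shape_example(%arg0: tensor<1x2x3xf32>) -> tensor<2x3xi32> {
  // Reducing over axis 0 drops that dimension: tensor<1x2x3xf32> -> tensor<2x3xi32>.
  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1x2x3xf32>) -> tensor<2x3xi32>
  return %0 : tensor<2x3xi32>
}
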
@llvmbot (Member) commented Mar 5, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

Fixes some test cases that incorrectly declared the output shape and adds a negative test case.


Full diff: https://github.com/llvm/llvm-project/pull/129870.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+21-4)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/constrained_shapes.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+8)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 800968e6f4766..bd5c5e56398c1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
 }
 
 LogicalResult tosa::ArgMaxOp::verify() {
+  const ShapedType resultType = llvm::cast<ShapedType>(getType());
+
   // Ensure output is of 32-bit integer
-  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
-  if (!resultETy.isIntOrIndex())
+  if (const auto resultETy = resultType.getElementType();
+      !resultETy.isIntOrIndex())
     return emitOpError("result tensor is not of integer type");
 
-  // Ensure axis is within the tensor rank
   const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  if (!inputType.hasRank())
+    return success();
+
+  // Ensure axis is within the tensor rank
   const int64_t axis = getAxisAttr().getInt();
-  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+  if (((axis < 0) || axis >= inputType.getRank()))
     return emitOpError("specified axis is outside the rank of the tensor");
 
+  if (!resultType.hasRank())
+    return success();
+
+  const ArrayRef<int64_t> inputShape = inputType.getShape();
+  const ArrayRef<int64_t> outputShape = resultType.getShape();
+  llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
+                                                 inputShape.end());
+  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
+    return emitOpError("expected output shape '")
+           << expectedOutputShape << "', got '" << outputShape << "'";
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a0184e2d82704..09aba79647c79 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
-  return %0 : tensor<?x1xi32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 8c3ad828ab06f..e06efbbfa1ad9 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -5,7 +5,7 @@
 // -----
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
-func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
+  return %0 : tensor<i32>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e665510ff0143..76093b0b3c1ca 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
   %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
   return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
 }
+
+// -----
+
+func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+  // expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
+  return %0 : tensor<1x2x3xi32>
+}

@llvmbot (Member) commented Mar 5, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Fixes some test cases that incorrectly declared the output shape and adds a negative test case.


@FranklandJack self-requested a review March 5, 2025 13:54

@FranklandJack (Contributor) left a comment

LGTM!

@GeorgeARM merged commit 4f46b75 into llvm:main Mar 6, 2025
14 checks passed