Skip to content

Conversation

@Tai78641
Copy link
Contributor

Add table size check for Table Op
and add lit tests to error_if_check.mlir
also corrected some existing tests that violated the table size checks

@llvmbot
Copy link
Member

llvmbot commented Apr 10, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

Add table size check for Table Op
and add lit tests to error_if_check.mlir
also corrected some existing tests that violated the table size checks


Full diff: https://github.com/llvm/llvm-project/pull/135262.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+23-1)
  • (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+16)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 28e562c813eb3..7b731be9ec8a5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -979,8 +979,30 @@ bool checkErrorIfResize(Operation *op) {
   return true;
 }
 
+bool checkErrorIfTable(Operation *op) {
+  auto table = dyn_cast<tosa::TableOp>(op);
+  if (!table)
+    return true;
+
+  // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
+  auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
+  int table_size = inputElemType.isInteger(8) ? 256 : 513;
+
+  const ShapeAdaptor tableShape(table.getTable().getType());
+  if (tableShape.hasStaticShape()) {
+    const auto numElements = tableShape.getNumElements();
+    if (numElements != table_size) {
+      op->emitOpError() << "requires table size of " << table_size << ", got "
+                        << numElements;
+      return false;
+    }
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op))
+  if (!checkErrorIfResize(op) || !checkErrorIfTable(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index 0ec46022157d7..25e1aa195c3a0 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -13,8 +13,8 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
 
 // -----
 
-func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
-  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index ce3ad04ea68ca..45ba8f76629d6 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -83,3 +83,19 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
   %1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+// CHECK-LABEL: test_i16_table_size
+func.func @test_i16_table_size(%arg0: tensor<2x64xi16>, %arg1: tensor<256xi16>) -> tensor<2x64xi32> {
+  // expected-error@+1 {{'tosa.table' op requires table size of 513, got 256}}
+    %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi16>, tensor<256xi16>) -> tensor<2x64xi32>
+    return %0 : tensor<2x64xi32>
+}
+
+// -----
+// CHECK-LABEL: test_i8_table_size
+func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) -> tensor<2x64xi8> {
+  // expected-error@+1 {{'tosa.table' op requires table size of 256, got 513}}
+    %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
+    return %0 : tensor<2x64xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 241e603e91c61..7386b1ba9df99 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -497,9 +497,9 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
 
 // -----
 
-func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
+func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
   // expected-error@+1 {{'tosa.table' op expected compile time resolvable constant, but got variable value for operand #1}}
-  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
+  %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
   return
 }
 

@Jerry-Ge Jerry-Ge requested review from GeorgeARM and lhutton1 April 10, 2025 21:28
Add table size check for Table Op
and add lit tests to error_if_check.mlir
also corrected some existing tests that violated the
table size checks

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I34b3dd95d90c473622ae5f18320b688fe4da0b0a
@Tai78641
Copy link
Contributor Author

rebased

Copy link
Member

@Jerry-Ge Jerry-Ge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the PR!

@Jerry-Ge Jerry-Ge merged commit 96064e1 into llvm:main Apr 15, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants