Skip to content

Commit 93bf221

Browse files
committed
[mlir][tosa] Add table size check for Table Op
Add table size check for Table Op and add lit tests to error_if_check.mlir also corrected some existing tests that violated the table size checks Signed-off-by: Tai Ly <[email protected]> Change-Id: I34b3dd95d90c473622ae5f18320b688fe4da0b0a
1 parent 801b519 commit 93bf221

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,30 @@ bool checkErrorIfResize(Operation *op) {
979979
return true;
980980
}
981981

982+
bool checkErrorIfTable(Operation *op) {
983+
auto table = dyn_cast<tosa::TableOp>(op);
984+
if (!table)
985+
return true;
986+
987+
// REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
988+
auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
989+
int table_size = inputElemType.isInteger(8) ? 256 : 513;
990+
991+
const ShapeAdaptor tableShape(table.getTable().getType());
992+
if (tableShape.hasStaticShape()) {
993+
const auto numElements = tableShape.getNumElements();
994+
if (numElements != table_size) {
995+
op->emitOpError() << "requires table size of " << table_size << ", got "
996+
<< numElements;
997+
return false;
998+
}
999+
}
1000+
1001+
return true;
1002+
}
1003+
9821004
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
983-
if (!checkErrorIfResize(op))
1005+
if (!checkErrorIfResize(op) || !checkErrorIfTable(op))
9841006
return failure();
9851007
return success();
9861008
}

mlir/test/Dialect/Tosa/dynamic_extension.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
1313

1414
// -----
1515

16-
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
17-
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
16+
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
17+
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
1818
return
1919
}
2020

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,19 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
8383
%1 = tosa.resize %arg0, %scale, %offset, %border { mode = "BILINEAR" } : (tensor<1x8x8x8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x?x?x?xf32>
8484
return %1 : tensor<?x?x?x?xf32>
8585
}
86+
87+
// -----
88+
// CHECK-LABEL: test_i16_table_size
89+
func.func @test_i16_table_size(%arg0: tensor<2x64xi16>, %arg1: tensor<256xi16>) -> tensor<2x64xi32> {
90+
// expected-error@+1 {{'tosa.table' op requires table size of 513, got 256}}
91+
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi16>, tensor<256xi16>) -> tensor<2x64xi32>
92+
return %0 : tensor<2x64xi32>
93+
}
94+
95+
// -----
96+
// CHECK-LABEL: test_i8_table_size
97+
func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) -> tensor<2x64xi8> {
98+
// expected-error@+1 {{'tosa.table' op requires table size of 256, got 513}}
99+
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
100+
return %0 : tensor<2x64xi8>
101+
}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,9 +497,9 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8
497497

498498
// -----
499499

500-
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
500+
func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () {
501501
// expected-error@+1 {{'tosa.table' op expected compile time resolvable constant, but got variable value for operand #1}}
502-
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8>
502+
%0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8>
503503
return
504504
}
505505

0 commit comments

Comments
 (0)