Skip to content

Commit fbdf4ec

Browse files
authored
[mlir][tosa] Fix invalid data type combinations check (llvm#150066)
Previously this check assumed that if an operator exists in profile complimance (TosaProfileComplianceData.h.inc), an entry exists in both the profiles and extensions section. However, this is not necessarily the case. This commit changes the check such that it doesn't assume the above. In doing so, it allows more operators to be checked for invalid data type combinations, which were otherwise skipped previously.
1 parent 9164d20 commit fbdf4ec

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
464464
CheckCondition condition = CheckCondition::invalid;
465465
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
466466
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
467+
if (failed(maybeProfDef) && failed(maybeExtDef))
468+
return success();
467469

468-
if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
469-
!maybeProfDef.value().size() && !maybeExtDef.value().size()) {
470+
const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
471+
(succeeded(maybeExtDef) && !maybeExtDef->empty());
472+
if (!hasEntry) {
470473
std::string message;
471474
llvm::raw_string_ostream os(message);
472475
os << "illegal: operation operand/result data types did not align with any "

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
20362036
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
20372037
return %0 : tensor<2x52x3xf32>
20382038
}
2039+
2040+
// -----
2041+
2042+
func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
2043+
// expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
2044+
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
2045+
return %0 : tensor<1x12x11xf32>
2046+
}
2047+
2048+
// -----
2049+
2050+
func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
2051+
// expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
2052+
%0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
2053+
return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
2054+
}

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
4848

4949
// -----
5050

51-
func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
51+
func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
5252
// expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
53-
%0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
54-
return %0 : tensor<1x1x1x1x13x21x3xf32>
53+
%0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
54+
return %0 : tensor<1x1x1x1x13x21x3xi32>
5555
}
5656

5757
// -----

0 commit comments

Comments
 (0)