Skip to content

Commit e76d236

Browse files
amramirsuderman
authored andcommitted
[mlir] tosa.concat - fix isCompatibleReturnTypes()
Reviewed By: eric-k256 Differential Revision: https://reviews.llvm.org/D146901
1 parent 88a097d commit e76d236

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,8 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
425425
bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
426426
if (l.size() != r.size() || l.size() != 1)
427427
return false;
428+
if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
429+
return false;
428430
return succeeded(verifyCompatibleShape(l[0], r[0]));
429431
}
430432

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
4646

4747
// -----
4848

49+
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
50+
// expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
51+
%0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
52+
return %0 : tensor<?x?xi8>
53+
}
54+
55+
// -----
56+
4957
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
5058
// expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
5159
%0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>

0 commit comments

Comments
 (0)