Skip to content

Commit 79af7bd

Browse files
authored
[mlir][tosa] Add AllElementTypesMatch trait for tosa.transpose (#120964)
This PR adds `AllElementTypesMatch` trait for `tosa.transpose` to ensure output tensor of same type as the input tensor. Fixes #119364.
1 parent 42dfaa1 commit 79af7bd

File tree

4 files changed

+11
-14
lines changed

4 files changed

+11
-14
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1698,7 +1698,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
16981698
// Operator: transpose
16991699
//===----------------------------------------------------------------------===//
17001700
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
1701-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
1701+
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1702+
AllElementTypesMatch<["input1", "output"]>]> {
17021703
let summary = "Transpose operator";
17031704

17041705
let description = [{

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
10021002
return input.reshape(resultTy);
10031003
}
10041004

1005-
// Transpose does not change the input type.
1006-
if (getInput1().getType() != getType())
1007-
return {};
1008-
10091005
// Transpose is not the identity transpose.
10101006
SmallVector<int32_t> perms;
10111007
if (getConstantPerms(perms).failed())

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
117117
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
118118
}
119119

120-
// CHECK-LABEL: @transpose_nofold_quantized_types
121-
func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
122-
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
123-
%input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
124-
// CHECK: tosa.transpose
125-
%0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
126-
return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
127-
}
128-
129120
// CHECK-LABEL: @transpose_nofold_dense_resource
130121
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
131122
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor
206206

207207
// -----
208208

209+
func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
210+
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
211+
// expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
212+
%1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
213+
return %1 : tensor<3x2xf32>
214+
}
215+
216+
// -----
217+
209218
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
210219
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
211220
%1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>

0 commit comments

Comments
 (0)