From 09691665d988eefd86b7a25b41b036b359401927 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Mon, 23 Dec 2024 19:34:04 +0800 Subject: [PATCH] [mlir][tosa] Add `AllElementTypesMatch` trait for `tosa.transpose` This PR adds `AllElementTypesMatch` trait for `tosa.transpose` to ensure output tensor of same type as the input tensor. --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 3 ++- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ---- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 9 --------- mlir/test/Dialect/Tosa/invalid.mlir | 9 +++++++++ 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index e3c725801d162..8ae5d3ab417b6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1698,7 +1698,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { // Operator: transpose //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods, + AllElementTypesMatch<["input1", "output"]>]> { let summary = "Transpose operator"; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 39d0ee122b163..f51c3dbce6eef 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { return input.reshape(resultTy); } - // Transpose does not change the input type. - if (getInput1().getType() != getType()) - return {}; - // Transpose is not the identity transpose. SmallVector perms; if (getConstantPerms(perms).failed()) diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 2902c4a62009e..8198903b78ac0 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> } -// CHECK-LABEL: @transpose_nofold_quantized_types -func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> { - %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> - %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8> - // CHECK: tosa.transpose - %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> - return %0: tensor<1x1x2x2x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01}>> -} - // CHECK-LABEL: @transpose_nofold_dense_resource func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> { %0 = "tosa.const"() <{value = dense_resource : tensor<2x2xf32>}> : () -> tensor<2x2xf32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index cca50b25d14d6..b796a6343e5ed 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -206,6 +206,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor // ----- +func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}} + %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> { %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>