Skip to content

Conversation

@CoTinker
Copy link
Contributor

This PR adds AllElementTypesMatch trait for tosa.transpose to ensure output tensor of same type as the input tensor. Fixes #119364.

This PR adds `AllElementTypesMatch` trait for `tosa.transpose` to ensure output tensor of same type as the input tensor.
@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2024

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR adds AllElementTypesMatch trait for tosa.transpose to ensure output tensor of same type as the input tensor. Fixes #119364.


Full diff: https://github.com/llvm/llvm-project/pull/120964.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (-4)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (-9)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..8ae5d3ab417b69 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1698,7 +1698,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
-                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+                 AllElementTypesMatch<["input1", "output"]>]> {
   let summary = "Transpose operator";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..f51c3dbce6eefe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
       return input.reshape(resultTy);
   }
 
-  // Transpose does not change the input type.
-  if (getInput1().getType() != getType())
-    return {};
-
   // Transpose is not the identity transpose.
   SmallVector<int32_t> perms;
   if (getConstantPerms(perms).failed())
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009e9..8198903b78ac05 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
   return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
 }
 
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
-  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
-  // CHECK: tosa.transpose
-  %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-}
-
 // CHECK-LABEL: @transpose_nofold_dense_resource
 func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
   %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..b796a6343e5ed1 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -206,6 +206,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor
 
 // -----
 
+func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
+  %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
 func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
   %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>

@CoTinker CoTinker merged commit 79af7bd into llvm:main Dec 30, 2024
11 checks passed
@CoTinker CoTinker deleted the transpose branch December 30, 2024 15:13
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Feb 7, 2025
Local branch amd-gfx 4096bad Merged main:16d19aaedf34 into amd-gfx:9ca74663be2c
Remote branch main 79af7bd [mlir][tosa] Add `AllElementTypesMatch` trait for `tosa.transpose` (llvm#120964)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] --canonicalize crashes in tosa::ReshapeOp::fold

4 participants