-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][tosa] Add AllElementTypesMatch trait for tosa.transpose
#120964
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This PR adds `AllElementTypesMatch` trait for `tosa.transpose` to ensure output tensor of same type as the input tensor.
Member
|
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR adds Full diff: https://github.com/llvm/llvm-project/pull/120964.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..8ae5d3ab417b69 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1698,7 +1698,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ AllElementTypesMatch<["input1", "output"]>]> {
let summary = "Transpose operator";
let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..f51c3dbce6eefe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
return input.reshape(resultTy);
}
- // Transpose does not change the input type.
- if (getInput1().getType() != getType())
- return {};
-
// Transpose is not the identity transpose.
SmallVector<int32_t> perms;
if (getConstantPerms(perms).failed())
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009e9..8198903b78ac05 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
- %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
- %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
- // CHECK: tosa.transpose
- %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
- return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-}
-
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..b796a6343e5ed1 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -206,6 +206,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor
// -----
+func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // expected-error@+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
+ %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
%1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
|
amrami
approved these changes
Dec 23, 2024
GeorgeARM
approved these changes
Dec 30, 2024
qiaojbao
pushed a commit
to GPUOpen-Drivers/llvm-project
that referenced
this pull request
Feb 7, 2025
Local branch amd-gfx 4096bad Merged main:16d19aaedf34 into amd-gfx:9ca74663be2c Remote branch main 79af7bd [mlir][tosa] Add `AllElementTypesMatch` trait for `tosa.transpose` (llvm#120964)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR adds
AllElementTypesMatchtrait fortosa.transposeto ensure output tensor of same type as the input tensor. Fixes #119364.