-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][Tosa] Allow scalar tensors to be unranked in the TOSA dialect #150399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Tosa] Allow scalar tensors to be unranked in the TOSA dialect #150399
Conversation
|
@llvm/pr-subscribers-mlir Author: woruyu (woruyu) ChangesSummaryThis PR resolves #149779 Full diff: https://github.com/llvm/llvm-project/pull/150399.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 1cfe6eee576b3..349e8edb0ed96 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -150,7 +150,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ed747145369d7..716362e094652 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -300,7 +300,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
- // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
+ // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
%2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
return
}
@@ -1006,7 +1006,7 @@ func.func @test_non_tosa_ops() {
func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
%cst = "tosa.const"() { values = dense<-0.0> : tensor<f8E4M3FN> } : () -> tensor<f8E4M3FN>
- // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
+ // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<f8E4M3FN>) -> tensor<13x21x3xf8E5M2>
return %0 : tensor<13x21x3xf8E5M2>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d0f40279421f4..9d43f8998da93 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -344,6 +344,19 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
// -----
+// CHECK-LABEL: @test_accepts_unranked_scalar_tensor
+func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
+ // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
+ %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE:.*]] = tosa.const_shape
+ %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
+ %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
+ return %2 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
|
|
@llvm/pr-subscribers-mlir-tosa Author: woruyu (woruyu) ChangesSummaryThis PR resolves #149779 Full diff: https://github.com/llvm/llvm-project/pull/150399.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 1cfe6eee576b3..349e8edb0ed96 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -150,7 +150,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
-def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
+def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index ed747145369d7..716362e094652 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -300,7 +300,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
- // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
+ // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
%2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
return
}
@@ -1006,7 +1006,7 @@ func.func @test_non_tosa_ops() {
func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
%cst = "tosa.const"() { values = dense<-0.0> : tensor<f8E4M3FN> } : () -> tensor<f8E4M3FN>
- // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
+ // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<f8E4M3FN>) -> tensor<13x21x3xf8E5M2>
return %0 : tensor<13x21x3xf8E5M2>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d0f40279421f4..9d43f8998da93 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -344,6 +344,19 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
// -----
+// CHECK-LABEL: @test_accepts_unranked_scalar_tensor
+func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
+ // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
+ %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE:.*]] = tosa.const_shape
+ %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
+ %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
+ return %2 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @test_table_static
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>
|
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @woruyu, LGTM!
|
It would also be good to allow |
Thanks for the suggestion! I'm happy to do that and will plan to submit a follow-up patch shortly. |
…lvm#150399) ### Summary This PR resolves llvm#149779
Summary
This PR resolves #149779