Skip to content

Commit de1f54a

Browse files
authored
[mlir][Tosa] Allow scalar tensors to be unranked in the TOSA dialect (#150399)
### Summary This PR resolves #149779
1 parent bba8467 commit de1f54a

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
150150

151151
def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
152152

153-
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
153+
def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>;
154154
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
155155
def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>;
156156

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
300300
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
301301
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
302302
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
303-
// expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
303+
// expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
304304
%2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32>
305305
return
306306
}
@@ -1006,7 +1006,7 @@ func.func @test_non_tosa_ops() {
10061006
func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> {
10071007
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
10081008
%cst = "tosa.const"() { values = dense<-0.0> : tensor<f8E4M3FN> } : () -> tensor<f8E4M3FN>
1009-
// expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
1009+
// expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<f8E4M3FN>'}}
10101010
%0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<f8E4M3FN>) -> tensor<13x21x3xf8E5M2>
10111011
return %0 : tensor<13x21x3xf8E5M2>
10121012
}

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,19 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
344344

345345
// -----
346346

347+
// CHECK-LABEL: @test_accepts_unranked_scalar_tensor
348+
func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
349+
// CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
350+
%0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
351+
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
352+
%1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
353+
// CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
354+
%2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
355+
return %2 : tensor<*xf32>
356+
}
357+
358+
// -----
359+
347360
// CHECK-LABEL: @test_table_static
348361
func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () {
349362
// CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>

0 commit comments

Comments
 (0)