diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 1cfe6eee576b3..349e8edb0ed96 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -150,7 +150,7 @@ class Tosa_TensorOfOrNone allowedTypes, string description = ""> : def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; -def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>; +def Tosa_ScalarTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_AnyNumber], [1]>]>; def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>; diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index ed747145369d7..716362e094652 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -300,7 +300,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) { %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> %1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32> - // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}} + // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}} %2 = tosa.pad %arg0, %0, %1 : (tensor<13x21xf32>, !tosa.shape<4>, tensor<2xf32>) -> tensor<13x21xf32> return } @@ -1006,7 +1006,7 @@ func.func @test_non_tosa_ops() { func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E5M2> { %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6> %cst = "tosa.const"() { values = dense<-0.0> : tensor } : () -> tensor - // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor'}} + // expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of number values, but got 'tensor'}} %0 = tosa.pad %arg0, %padding, %cst : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor) -> tensor<13x21x3xf8E5M2> return %0 : tensor<13x21x3xf8E5M2> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index d0f40279421f4..9d43f8998da93 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -344,6 +344,19 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten // ----- +// CHECK-LABEL: @test_accepts_unranked_scalar_tensor +func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> { + // CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32> + %0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32> + // CHECK: %[[SHAPE:.*]] = tosa.const_shape + %1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> + // CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32> + %2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: @test_table_static func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>