diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 536551c8f8437..1cfe6eee576b3 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -152,7 +152,7 @@ def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>; def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; -def Tosa_ScalarIntOrFloatTensor : TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>; +def Tosa_ScalarIntOrFloatTensor : AnyTypeOf<[TosaUnrankedTensorOf<[Tosa_Int, AnyFloat]>, TosaScalarTensorOf<[Tosa_Int, AnyFloat], [1]>]>; // We include unranked tensors as a supported type for all possible tosa // Tensors as unranked does not guarantee invalid. If unranked tensors exist diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 805522799a6d8..6fa6756822acb 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1007,7 +1007,7 @@ func.func @test_pad_rank0_pad_const(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1 func.func @test_conv2d_rank0_zp(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi32> { %input_zp = "tosa.const"() <{values = dense<0> : tensor}> : () -> tensor %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> - // expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor'}} + // expected-error@+1 {{'tosa.conv2d' op operand #3 must be tosa-conformant unranked tensor of unsigned integer or signless integer or floating-point values or tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values, but got 'tensor'}} %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor, tensor<1xi8>) -> tensor<1x27x27x16xi32> return %0 : tensor<1x27x27x16xi32> diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 591a3f0acf65d..edfe95614abdc 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -333,6 +333,17 @@ func.func @test_dynamic_mixed_matmul(%arg0 : tensor, %arg1 : tensor, %arg1: tensor<1x3x4xf32>, %zero_point: tensor<1xf32>) -> tensor<1x2x4xf32> { + // CHECK: %[[ZP:.*]] = tosa.cast %arg2 : (tensor<1xf32>) -> tensor<1xf32> + %zero_point_unranked = "tosa.cast"(%zero_point) : (tensor<1xf32>) -> tensor<*xf32> + // CHECK: tosa.matmul %arg0, %arg1, %[[ZP]], %[[ZP]] : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x4xf32> + %0 = tosa.matmul %arg0, %arg1, %zero_point_unranked, %zero_point_unranked : (tensor<1x2x3xf32>, tensor<1x3x4xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x2x4xf32> + return %0 : tensor<1x2x4xf32> +} + +// ----- + // CHECK-LABEL: @test_table_static func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16>