diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b79993f48b379..0c99dd6130c2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor: $input,
+    Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
@@ -1629,12 +1629,12 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1668,12 +1668,12 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1707,13 +1707,13 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1748,13 +1748,13 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1789,12 +1789,12 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1828,12 +1828,12 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1872,12 +1872,12 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$input1,
+    Variadic<Tosa_TensorAtLeast1D>:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1923,13 +1923,13 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   }];
 
   let arguments = (ins
-    Tosa_RankedTensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$padding,
     Tosa_ScalarTensor:$pad_const
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1996,12 +1996,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2028,13 +2028,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$start,
     Tosa_Shape:$size
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2058,11 +2058,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$multiples);
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2093,12 +2093,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     DenseI32ArrayAttr:$perms
   );
 
   let results = (
-    outs Tosa_Tensor:$output
+    outs Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 0038d8c386ca7..67011f22fbe2a 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -101,6 +101,10 @@ def AllDimensionsAreSizeOne : And<[ IsRankedTensorTypePred, CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>; +def AtLeastRankOne : And<[ + IsRankedTensorTypePred, + CPred<"::llvm::cast<::mlir::RankedTensorType>($_self).getRank() >= 1">]>; + class TosaTensorOf< list allowedTypes, string summary = "tosa-conformant tensor"> : TensorOf], summary>; @@ -183,6 +187,9 @@ def Tosa_TensorUpto4D : AnyTypeOf<[ def Tosa_Int32TensorUpto4D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>; +def Tosa_TensorAtLeast1D : AnyTypeOf<[ + Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">; + //===----------------------------------------------------------------------===// // Generic scalar, vector, or tensor of a particular type. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 97a3009a20302..cdba332792eb0 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1354,8 +1354,13 @@ LogicalResult tosa::PadOp::verify() { } } - RankedTensorType inputType = getInput1().getType(); - RankedTensorType outputType = getOutput().getType(); + RankedTensorType inputType = + llvm::dyn_cast(getInput1().getType()); + RankedTensorType outputType = + llvm::dyn_cast(getOutput().getType()); + if (!inputType || !outputType) + return success(); + auto paddingRank = cast(getPadding().getType()).getRank(); if (inputType.getRank() != outputType.getRank()) diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 3bc438e465e1d..077a6cee0a1bb 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -915,29 +915,6 @@ func.func @fold_abs_abs(%arg0: tensor) -> tensor { // ----- -// CHECK-LABEL: @fold_reduce_rank_zero -func.func @fold_reduce_rank_zero() { - // CHECK-NOT: tosa.reduce_min - // CHECK-NOT: tosa.reverse - %0 = tensor.empty() : tensor - %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor) -> tensor - %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor) -> tensor - return -} - -// ----- - -// CHECK-LABEL: @fold_tile_rank_zero -func.func nested @fold_tile_rank_zero() -> tensor { - // CHECK-NOT: tosa.tile - %0 = tensor.empty() : tensor - %cst = tosa.const_shape { values = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> - %1 = tosa.tile %0, %cst : (tensor, !tosa.shape<0>) -> tensor - return %1 : tensor -} - -// ----- - // CHECK-LABEL: @reshape_quant_nofold // check that segfault is fixed func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> { @@ -1015,12 +992,12 @@ func.func @cast_quant_nofold() -> tensor tensor> { - // CHECK: %[[CST:.*]] = "tosa.const"() 
<{values = dense<0> : tensor}> : () -> tensor> +func.func @reverse_quant_fold() -> tensor<1x!quant.uniform> { + // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> // CHECK: return %[[CST]] - %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor> - %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor>) -> tensor> - return %1 : tensor> + %0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> + %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<1x!quant.uniform>) -> tensor<1x!quant.uniform> + return %1 : tensor<1x!quant.uniform> } // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index a488c051dcd3b..2dc749422c12d 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -452,9 +452,9 @@ func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () { // ----- -func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor) -> () { +func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<1xi32>) -> () { // expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}} - %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor) -> tensor<1x10xi32> + %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1x10xi32> return } @@ -1852,3 +1852,72 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>) (tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32> return %0 : tensor<1x32x2x8xf32> } + +// ----- + +func.func @test_scalar_argmax(%arg0: tensor) -> tensor { + // expected-error@+1 {{'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor'}} + %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_scalar_reduce_all(%arg0: tensor) -> tensor { + // expected-error@+1 {{'tosa.reduce_all' op operand #0 must be 
tosa-conformant tensor of at least rank 1, but got 'tensor<i1>'}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<i1>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// -----
+
+func.func @test_scalar_inputs_concat(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
+  // expected-error@+1 {{'tosa.concat' op operand #0 must be variadic of tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @test_scalar_pad(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<f32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_reverse(%arg0: tensor<f32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.reverse' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.reverse %arg0 {axis = 0: i32} : (tensor<f32>) -> tensor<f32>
+  return %arg0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.slice' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
+  %cst = tosa.const_shape { values = dense<[]> : tensor<0xindex> } : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.tile' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.tile %arg0, %cst: (tensor<f32>, !tosa.shape<0>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.transpose %arg0 {perms = array<i32>} : (tensor<*xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}