diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index ccdbb60a1fe9a..3c4aabb9768d1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2057,11 +2057,11 @@ def Tosa_CustomOp : Tosa_Op<"custom"> { StrAttr:$operator_name, StrAttr:$domain_name, StrAttr:$implementation_attrs, - Variadic:$input_list + Variadic:$input_list ); let results = (outs - Variadic:$output_list + Variadic:$output_list ); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 5ca7720508d54..99cd3a0c864de 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -135,6 +135,8 @@ def Tosa_ElementType : Type allowedTypes, string description = ""> : AnyTypeOf<[TosaTensorOf, NoneType], description>; +def Tosa_TensorOrNone : Tosa_TensorOfOrNone<[Tosa_AnyNumber]>; + //===----------------------------------------------------------------------===// // Tensor types with constrained ranks. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 690e208af1e5f..1101475321d49 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -694,6 +694,13 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> { return %0 : tensor<10xi32> } +// ----- +// CHECK-LABEL: test_custom_none +func.func @test_custom_none(%arg0: tensor<10xi32>, %arg1: none) -> tensor<10xi32> { + %0 = tosa.custom %arg0, %arg1 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>, none) -> (tensor<10xi32>) + return %0 : tensor<10xi32> +} + // ----- // CHECK-LABEL: const_shape func.func @test_const_shape() -> !tosa.shape<4> {