Skip to content

Commit 422720a

Browse files
committed
Allow operands and results of tosa.custom to be None. This enables lowering from dialects that support None operands or results
1 parent ee40aef commit 422720a

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,11 +2057,11 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
20572057
StrAttr:$operator_name,
20582058
StrAttr:$domain_name,
20592059
StrAttr:$implementation_attrs,
2060-
Variadic<Tosa_Tensor>:$input_list
2060+
Variadic<Tosa_TensorOrNone>:$input_list
20612061
);
20622062

20632063
let results = (outs
2064-
Variadic<Tosa_Tensor>:$output_list
2064+
Variadic<Tosa_TensorOrNone>:$output_list
20652065
);
20662066

20672067
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
135135
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
136136
AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
137137

138+
def Tosa_TensorOrNone : Tosa_TensorOfOrNone<[Tosa_AnyNumber]>;
139+
138140
//===----------------------------------------------------------------------===//
139141
// Tensor types with constrained ranks.
140142
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,13 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> {
694694
return %0 : tensor<10xi32>
695695
}
696696

697+
// -----
698+
// CHECK-LABEL: test_custom_none
699+
func.func @test_custom_none(%arg0: tensor<10xi32>, %arg1: none) -> tensor<10xi32> {
700+
%0 = tosa.custom %arg0, %arg1 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>, none) -> (tensor<10xi32>)
701+
return %0 : tensor<10xi32>
702+
}
703+
697704
// -----
698705
// CHECK-LABEL: const_shape
699706
func.func @test_const_shape() -> !tosa.shape<4> {

0 commit comments

Comments
 (0)