1515 NodeVisitor ,
1616 register_node_visitor ,
1717)
18+ from executorch .backends .arm .operators .operator_validation_utils import (
19+ validate_num_inputs ,
20+ )
1821
1922from executorch .backends .arm .tosa_mapping import TosaArg
2023from executorch .backends .arm .tosa_specification import TosaSpecification
@@ -65,9 +68,6 @@ def cast_type(value: Any) -> int | float:
6568 # Attempt to cast to float
6669 return float (value )
6770
68- if len (node .args ) != 2 and len (node .args ) != 3 :
69- raise ValueError (f"Expected len(node.args) to be 2 or 3, got { node .args } " )
70-
7171 min_arg = dtype_min
7272 max_arg = dtype_max
7373
@@ -87,10 +87,7 @@ def define_node(
8787 inputs : List [TosaArg ],
8888 output : TosaArg ,
8989 ) -> None :
90- if len (node .all_input_nodes ) != 1 :
91- raise ValueError (
92- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
93- )
90+ validate_num_inputs (self .target , inputs , [2 , 3 ])
9491
9592 min_int8 , max_int8 = self ._get_min_max_arguments (
9693 node ,
@@ -130,10 +127,7 @@ def define_node(
130127 ) -> None :
131128 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
132129
133- if len (node .all_input_nodes ) != 1 :
134- raise ValueError (
135- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
136- )
130+ validate_num_inputs (self .target , inputs , [2 , 3 ])
137131
138132 if inputs [0 ].dtype == ts .DType .INT8 :
139133 # Call the inherited define_node for handling integers
@@ -178,9 +172,6 @@ def cast_type(value: Any) -> int | float:
178172 # Attempt to cast to float
179173 return float (value )
180174
181- if len (node .args ) != 2 and len (node .args ) != 3 :
182- raise ValueError (f"Expected len(node.args) to be 2 or 3, got { node .args } " )
183-
184175 min_arg = dtype_min
185176 max_arg = dtype_max
186177
@@ -202,10 +193,7 @@ def define_node(
202193 ) -> None :
203194 import serializer .tosa_serializer as ts # type: ignore
204195
205- if len (node .all_input_nodes ) != 1 :
206- raise ValueError (
207- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
208- )
196+ validate_num_inputs (self .target , inputs , [2 , 3 ])
209197
210198 # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
211199 min_int8 , max_int8 = self ._get_min_max_arguments (
@@ -247,10 +235,7 @@ def define_node(
247235 ) -> None :
248236 import serializer .tosa_serializer as ts # type: ignore
249237
250- if len (node .all_input_nodes ) != 1 :
251- raise ValueError (
252- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
253- )
238+ validate_num_inputs (self .target , inputs , [2 , 3 ])
254239
255240 min_fp32 , max_fp32 = self ._get_min_max_arguments (
256241 node ,
0 commit comments