15
15
NodeVisitor ,
16
16
register_node_visitor ,
17
17
)
18
+ from executorch .backends .arm .operators .operator_validation_utils import (
19
+ validate_num_inputs ,
20
+ )
18
21
19
22
from executorch .backends .arm .tosa_mapping import TosaArg
20
23
from executorch .backends .arm .tosa_specification import TosaSpecification
@@ -65,9 +68,6 @@ def cast_type(value: Any) -> int | float:
65
68
# Attempt to cast to float
66
69
return float (value )
67
70
68
- if len (node .args ) != 2 and len (node .args ) != 3 :
69
- raise ValueError (f"Expected len(node.args) to be 2 or 3, got { node .args } " )
70
-
71
71
min_arg = dtype_min
72
72
max_arg = dtype_max
73
73
@@ -87,10 +87,7 @@ def define_node(
87
87
inputs : List [TosaArg ],
88
88
output : TosaArg ,
89
89
) -> None :
90
- if len (node .all_input_nodes ) != 1 :
91
- raise ValueError (
92
- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
93
- )
90
+ validate_num_inputs (self .target , inputs , [2 , 3 ])
94
91
95
92
min_int8 , max_int8 = self ._get_min_max_arguments (
96
93
node ,
@@ -130,10 +127,7 @@ def define_node(
130
127
) -> None :
131
128
import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
132
129
133
- if len (node .all_input_nodes ) != 1 :
134
- raise ValueError (
135
- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
136
- )
130
+ validate_num_inputs (self .target , inputs , [2 , 3 ])
137
131
138
132
if inputs [0 ].dtype == ts .DType .INT8 :
139
133
# Call the inherited define_node for handling integers
@@ -178,9 +172,6 @@ def cast_type(value: Any) -> int | float:
178
172
# Attempt to cast to float
179
173
return float (value )
180
174
181
- if len (node .args ) != 2 and len (node .args ) != 3 :
182
- raise ValueError (f"Expected len(node.args) to be 2 or 3, got { node .args } " )
183
-
184
175
min_arg = dtype_min
185
176
max_arg = dtype_max
186
177
@@ -202,10 +193,7 @@ def define_node(
202
193
) -> None :
203
194
import serializer .tosa_serializer as ts # type: ignore
204
195
205
- if len (node .all_input_nodes ) != 1 :
206
- raise ValueError (
207
- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
208
- )
196
+ validate_num_inputs (self .target , inputs , [2 , 3 ])
209
197
210
198
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
211
199
min_int8 , max_int8 = self ._get_min_max_arguments (
@@ -247,10 +235,7 @@ def define_node(
247
235
) -> None :
248
236
import serializer .tosa_serializer as ts # type: ignore
249
237
250
- if len (node .all_input_nodes ) != 1 :
251
- raise ValueError (
252
- f"Expected 1 input for { self .target } , got { len (node .all_input_nodes )} "
253
- )
238
+ validate_num_inputs (self .target , inputs , [2 , 3 ])
254
239
255
240
min_fp32 , max_fp32 = self ._get_min_max_arguments (
256
241
node ,
0 commit comments