|
16 | 16 | ) |
17 | 17 | from executorch.backends.arm.operators.operator_validation_utils import ( |
18 | 18 | validate_num_inputs, |
| 19 | + validate_same_dtype, |
19 | 20 | ) |
20 | 21 | from executorch.backends.arm.tosa_mapping import TosaArg |
21 | 22 | from executorch.backends.arm.tosa_specification import TosaSpecification |
@@ -44,14 +45,8 @@ def define_node( |
44 | 45 | import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore |
45 | 46 |
|
46 | 47 | validate_num_inputs(self.target, inputs, 2) |
47 | | - # Specification (0.80) states that input and output types |
48 | | - # should all be the same |
49 | | - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
50 | | - raise TypeError( |
51 | | - f"All IO needs to have the same data type, got input 1: " |
52 | | - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
53 | | - f"{output.dtype}" |
54 | | - ) |
| 48 | + validate_same_dtype(self.target, [*inputs, output]) |
| 49 | + |
55 | 50 | # Handle int8 (quantized) and int32 |
56 | 51 | supported_dtypes = [ts.DType.INT8, ts.DType.INT32] |
57 | 52 | if inputs[0].dtype not in supported_dtypes: |
@@ -123,14 +118,7 @@ def define_node( |
123 | 118 | import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore |
124 | 119 |
|
125 | 120 | validate_num_inputs(self.target, inputs, 2) |
126 | | - # Specification (0.80) states that input and output types |
127 | | - # should all be the same |
128 | | - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
129 | | - raise TypeError( |
130 | | - f"All IO needs to have the same data type, got input 1: " |
131 | | - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
132 | | - f"{output.dtype}" |
133 | | - ) |
| 121 | + validate_same_dtype(self.target, [*inputs, output]) |
134 | 122 |
|
135 | 123 | if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: |
136 | 124 | # Call the inherited define_node for handling integers |
@@ -175,15 +163,8 @@ def define_node( |
175 | 163 | import serializer.tosa_serializer as ts # type: ignore |
176 | 164 |
|
177 | 165 | validate_num_inputs(self.target, inputs, 2) |
| 166 | + validate_same_dtype(self.target, [*inputs, output]) |
178 | 167 |
|
179 | | - # Specification (1.0) states that input and output types |
180 | | - # should all be the same |
181 | | - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
182 | | - raise TypeError( |
183 | | - f"All IO needs to have the same data type, got input 1: " |
184 | | - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
185 | | - f"{output.dtype}" |
186 | | - ) |
187 | 168 | # Handle int8 (quantized) and int32 |
188 | 169 | supported_dtypes = [ts.DType.INT8, ts.DType.INT32] |
189 | 170 | if inputs[0].dtype not in supported_dtypes: |
@@ -245,15 +226,7 @@ def define_node( |
245 | 226 | import serializer.tosa_serializer as ts # type: ignore |
246 | 227 |
|
247 | 228 | validate_num_inputs(self.target, inputs, 2) |
248 | | - |
249 | | - # Specification (1.0) states that input and output types |
250 | | - # should all be the same |
251 | | - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
252 | | - raise TypeError( |
253 | | - f"All IO needs to have the same data type, got input 1: " |
254 | | - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
255 | | - f"{output.dtype}" |
256 | | - ) |
| 229 | + validate_same_dtype(self.target, [*inputs, output]) |
257 | 230 |
|
258 | 231 | if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: |
259 | 232 | # Call the inherited define_node for handling integers |
|
0 commit comments