|
16 | 16 | )
|
17 | 17 | from executorch.backends.arm.operators.operator_validation_utils import (
|
18 | 18 | validate_num_inputs,
|
| 19 | + validate_same_dtype, |
19 | 20 | )
|
20 | 21 | from executorch.backends.arm.tosa_mapping import TosaArg
|
21 | 22 | from executorch.backends.arm.tosa_specification import TosaSpecification
|
@@ -44,14 +45,8 @@ def define_node(
|
44 | 45 | import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
|
45 | 46 |
|
46 | 47 | validate_num_inputs(self.target, inputs, 2)
|
47 |
| - # Specification (0.80) states that input and output types |
48 |
| - # should all be the same |
49 |
| - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
50 |
| - raise TypeError( |
51 |
| - f"All IO needs to have the same data type, got input 1: " |
52 |
| - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
53 |
| - f"{output.dtype}" |
54 |
| - ) |
| 48 | + validate_same_dtype(self.target, [*inputs, output]) |
| 49 | + |
55 | 50 | # Handle int8 (quantized) and int32
|
56 | 51 | supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
|
57 | 52 | if inputs[0].dtype not in supported_dtypes:
|
@@ -123,14 +118,7 @@ def define_node(
|
123 | 118 | import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
|
124 | 119 |
|
125 | 120 | validate_num_inputs(self.target, inputs, 2)
|
126 |
| - # Specification (0.80) states that input and output types |
127 |
| - # should all be the same |
128 |
| - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
129 |
| - raise TypeError( |
130 |
| - f"All IO needs to have the same data type, got input 1: " |
131 |
| - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
132 |
| - f"{output.dtype}" |
133 |
| - ) |
| 121 | + validate_same_dtype(self.target, [*inputs, output]) |
134 | 122 |
|
135 | 123 | if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
|
136 | 124 | # Call the inherited define_node for handling integers
|
@@ -175,15 +163,8 @@ def define_node(
|
175 | 163 | import serializer.tosa_serializer as ts # type: ignore
|
176 | 164 |
|
177 | 165 | validate_num_inputs(self.target, inputs, 2)
|
| 166 | + validate_same_dtype(self.target, [*inputs, output]) |
178 | 167 |
|
179 |
| - # Specification (1.0) states that input and output types |
180 |
| - # should all be the same |
181 |
| - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
182 |
| - raise TypeError( |
183 |
| - f"All IO needs to have the same data type, got input 1: " |
184 |
| - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
185 |
| - f"{output.dtype}" |
186 |
| - ) |
187 | 168 | # Handle int8 (quantized) and int32
|
188 | 169 | supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
|
189 | 170 | if inputs[0].dtype not in supported_dtypes:
|
@@ -245,15 +226,7 @@ def define_node(
|
245 | 226 | import serializer.tosa_serializer as ts # type: ignore
|
246 | 227 |
|
247 | 228 | validate_num_inputs(self.target, inputs, 2)
|
248 |
| - |
249 |
| - # Specification (1.0) states that input and output types |
250 |
| - # should all be the same |
251 |
| - if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: |
252 |
| - raise TypeError( |
253 |
| - f"All IO needs to have the same data type, got input 1: " |
254 |
| - f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " |
255 |
| - f"{output.dtype}" |
256 |
| - ) |
| 229 | + validate_same_dtype(self.target, [*inputs, output]) |
257 | 230 |
|
258 | 231 | if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
|
259 | 232 | # Call the inherited define_node for handling integers
|
|
0 commit comments