|
20 | 20 | adjust_pooling_pad_if_needed, |
21 | 21 | validate_num_inputs, |
22 | 22 | validate_same_dtype, |
| 23 | + validate_valid_dtype, |
23 | 24 | ) |
24 | 25 | from executorch.backends.arm.tosa_mapping import TosaArg |
25 | 26 | from executorch.backends.arm.tosa_specification import TosaSpecification |
@@ -106,13 +107,9 @@ def define_node( |
106 | 107 |
|
107 | 108 | validate_num_inputs(self.target, inputs, [3, 4, 6]) |
108 | 109 | validate_same_dtype(self.target, [inputs[0], output], ts) |
109 | | - |
110 | | - supported_dtypes = [ts.DType.INT8] |
111 | | - if inputs[0].dtype not in supported_dtypes: |
112 | | - raise TypeError( |
113 | | - f"IO data type needs to be one of {supported_dtypes}, got " |
114 | | - f'"{inputs[0].dtype}"' |
115 | | - ) |
| 110 | + validate_valid_dtype( |
| 111 | + self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec |
| 112 | + ) |
116 | 113 |
|
117 | 114 | accumulator_type = ts.DType.INT32 |
118 | 115 |
|
@@ -146,13 +143,12 @@ def define_node( |
146 | 143 |
|
147 | 144 | validate_num_inputs(self.target, inputs, [3, 4, 6]) |
148 | 145 | validate_same_dtype(self.target, [inputs[0], output], ts) |
149 | | - |
150 | | - supported_dtypes = [ts.DType.INT8, ts.DType.FP32] |
151 | | - if inputs[0].dtype not in supported_dtypes: |
152 | | - raise TypeError( |
153 | | - f"IO data type needs to be one of {supported_dtypes}, got " |
154 | | - f'"{inputs[0].dtype}"' |
155 | | - ) |
| 146 | + validate_valid_dtype( |
| 147 | + self.target, |
| 148 | + [inputs[0], output], |
| 149 | + [ts.DType.INT8, ts.DType.FP32], |
| 150 | + output.tosa_spec, |
| 151 | + ) |
156 | 152 |
|
157 | 153 | if inputs[0].dtype == ts.DType.INT8: |
158 | 154 | super().define_node(node, tosa_graph, inputs, output) |
@@ -253,13 +249,9 @@ def define_node( |
253 | 249 |
|
254 | 250 | validate_num_inputs(self.target, inputs, [3, 4, 6]) |
255 | 251 | validate_same_dtype(self.target, [inputs[0], output], ts) |
256 | | - |
257 | | - supported_dtypes = [ts.DType.INT8] |
258 | | - if inputs[0].dtype not in supported_dtypes: |
259 | | - raise TypeError( |
260 | | - f"IO data type needs to be one of {supported_dtypes}, got " |
261 | | - f'"{inputs[0].dtype}"' |
262 | | - ) |
| 252 | + validate_valid_dtype( |
| 253 | + self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec |
| 254 | + ) |
263 | 255 |
|
264 | 256 | accumulator_type = ts.DType.INT32 |
265 | 257 |
|
@@ -296,13 +288,12 @@ def define_node( |
296 | 288 |
|
297 | 289 | validate_num_inputs(self.target, inputs, [3, 4, 6]) |
298 | 290 | validate_same_dtype(self.target, [inputs[0], output], ts) |
299 | | - |
300 | | - supported_dtypes = [ts.DType.INT8, ts.DType.FP32] |
301 | | - if inputs[0].dtype not in supported_dtypes: |
302 | | - raise TypeError( |
303 | | - f"IO data type needs to be one of {supported_dtypes}, got " |
304 | | - f'"{inputs[0].dtype}"' |
305 | | - ) |
| 291 | + validate_valid_dtype( |
| 292 | + self.target, |
| 293 | + [inputs[0], output], |
| 294 | + [ts.DType.INT8, ts.DType.FP32], |
| 295 | + output.tosa_spec, |
| 296 | + ) |
306 | 297 |
|
307 | 298 | if inputs[0].dtype == ts.DType.INT8: |
308 | 299 | super().define_node(node, tosa_graph, inputs, output) |
|
0 commit comments