|
20 | 20 | adjust_pooling_pad_if_needed,
|
21 | 21 | validate_num_inputs,
|
22 | 22 | validate_same_dtype,
|
| 23 | + validate_valid_dtype, |
23 | 24 | )
|
24 | 25 | from executorch.backends.arm.tosa_mapping import TosaArg
|
25 | 26 | from executorch.backends.arm.tosa_specification import TosaSpecification
|
@@ -106,13 +107,9 @@ def define_node(
|
106 | 107 |
|
107 | 108 | validate_num_inputs(self.target, inputs, [3, 4, 6])
|
108 | 109 | validate_same_dtype(self.target, [inputs[0], output], ts)
|
109 |
| - |
110 |
| - supported_dtypes = [ts.DType.INT8] |
111 |
| - if inputs[0].dtype not in supported_dtypes: |
112 |
| - raise TypeError( |
113 |
| - f"IO data type needs to be one of {supported_dtypes}, got " |
114 |
| - f'"{inputs[0].dtype}"' |
115 |
| - ) |
| 110 | + validate_valid_dtype( |
| 111 | + self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec |
| 112 | + ) |
116 | 113 |
|
117 | 114 | accumulator_type = ts.DType.INT32
|
118 | 115 |
|
@@ -146,13 +143,12 @@ def define_node(
|
146 | 143 |
|
147 | 144 | validate_num_inputs(self.target, inputs, [3, 4, 6])
|
148 | 145 | validate_same_dtype(self.target, [inputs[0], output], ts)
|
149 |
| - |
150 |
| - supported_dtypes = [ts.DType.INT8, ts.DType.FP32] |
151 |
| - if inputs[0].dtype not in supported_dtypes: |
152 |
| - raise TypeError( |
153 |
| - f"IO data type needs to be one of {supported_dtypes}, got " |
154 |
| - f'"{inputs[0].dtype}"' |
155 |
| - ) |
| 146 | + validate_valid_dtype( |
| 147 | + self.target, |
| 148 | + [inputs[0], output], |
| 149 | + [ts.DType.INT8, ts.DType.FP32], |
| 150 | + output.tosa_spec, |
| 151 | + ) |
156 | 152 |
|
157 | 153 | if inputs[0].dtype == ts.DType.INT8:
|
158 | 154 | super().define_node(node, tosa_graph, inputs, output)
|
@@ -253,13 +249,9 @@ def define_node(
|
253 | 249 |
|
254 | 250 | validate_num_inputs(self.target, inputs, [3, 4, 6])
|
255 | 251 | validate_same_dtype(self.target, [inputs[0], output], ts)
|
256 |
| - |
257 |
| - supported_dtypes = [ts.DType.INT8] |
258 |
| - if inputs[0].dtype not in supported_dtypes: |
259 |
| - raise TypeError( |
260 |
| - f"IO data type needs to be one of {supported_dtypes}, got " |
261 |
| - f'"{inputs[0].dtype}"' |
262 |
| - ) |
| 252 | + validate_valid_dtype( |
| 253 | + self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec |
| 254 | + ) |
263 | 255 |
|
264 | 256 | accumulator_type = ts.DType.INT32
|
265 | 257 |
|
@@ -296,13 +288,12 @@ def define_node(
|
296 | 288 |
|
297 | 289 | validate_num_inputs(self.target, inputs, [3, 4, 6])
|
298 | 290 | validate_same_dtype(self.target, [inputs[0], output], ts)
|
299 |
| - |
300 |
| - supported_dtypes = [ts.DType.INT8, ts.DType.FP32] |
301 |
| - if inputs[0].dtype not in supported_dtypes: |
302 |
| - raise TypeError( |
303 |
| - f"IO data type needs to be one of {supported_dtypes}, got " |
304 |
| - f'"{inputs[0].dtype}"' |
305 |
| - ) |
| 291 | + validate_valid_dtype( |
| 292 | + self.target, |
| 293 | + [inputs[0], output], |
| 294 | + [ts.DType.INT8, ts.DType.FP32], |
| 295 | + output.tosa_spec, |
| 296 | + ) |
306 | 297 |
|
307 | 298 | if inputs[0].dtype == ts.DType.INT8:
|
308 | 299 | super().define_node(node, tosa_graph, inputs, output)
|
|
0 commit comments