Skip to content

Commit d9b3cf4

Browse files
Arm backend: Add validate_valid_dtype (#11472)
Add a new validate_valid_dtype helper in operator_validation_utils to centralize and standardize dtype validation logic across all ARM backend operators. Manual, duplicated dtype checks in individual operator visitors have been replaced with calls to validate_valid_dtype. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent d660bde commit d9b3cf4

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed (+600, -298 lines changed)

backends/arm/operators/op_abs.py

Lines changed: 19 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from executorch.backends.arm.operators.operator_validation_utils import (
1717
validate_num_inputs,
1818
validate_same_dtype,
19+
validate_valid_dtype,
1920
)
2021
from executorch.backends.arm.tosa_mapping import TosaArg
2122
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -45,12 +46,13 @@ def define_node(
4546

4647
validate_num_inputs(self.target, inputs, 1)
4748
validate_same_dtype(self.target, [*inputs, output], ts)
48-
4949
# Handle int8 (quantized) and int32
50-
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
51-
raise ValueError(
52-
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
53-
)
50+
validate_valid_dtype(
51+
self.target,
52+
[*inputs, output],
53+
[ts.DType.INT8, ts.DType.INT32],
54+
output.tosa_spec,
55+
)
5456

5557
if inputs[0].dtype == ts.DType.INT8:
5658
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
@@ -113,14 +115,9 @@ def define_node(
113115
super().define_node(node, tosa_graph, inputs, output)
114116
else:
115117
# FP32 Abs lowering
116-
117-
if not (inputs[0].dtype == ts.DType.FP32):
118-
raise ValueError(
119-
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
120-
)
121-
122-
if not (output.dtype == ts.DType.FP32):
123-
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
118+
validate_valid_dtype(
119+
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
120+
)
124121

125122
# MI lowering
126123
tosa_graph.addOperator(
@@ -156,10 +153,12 @@ def define_node(
156153
validate_same_dtype(self.target, [*inputs, output], ts)
157154

158155
# Handle int8 (quantized) and int32
159-
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
160-
raise ValueError(
161-
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
162-
)
156+
validate_valid_dtype(
157+
self.target,
158+
[*inputs, output],
159+
[ts.DType.INT8, ts.DType.INT32],
160+
output.tosa_spec,
161+
)
163162

164163
scale_back = 1.0
165164
if inputs[0].dtype == ts.DType.INT8:
@@ -224,13 +223,9 @@ def define_node(
224223
else:
225224
# FP32 Abs lowering
226225

227-
if not (inputs[0].dtype == ts.DType.FP32):
228-
raise ValueError(
229-
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
230-
)
231-
232-
if not (output.dtype == ts.DType.FP32):
233-
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
226+
validate_valid_dtype(
227+
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
228+
)
234229

235230
# MI lowering
236231
tosa_graph.addOperator(

backends/arm/operators/op_add.py

Lines changed: 19 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from executorch.backends.arm.operators.operator_validation_utils import (
1818
validate_num_inputs,
1919
validate_same_dtype,
20+
validate_valid_dtype,
2021
)
2122
from executorch.backends.arm.tosa_mapping import TosaArg
2223
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -46,13 +47,12 @@ def define_node(
4647

4748
validate_num_inputs(self.target, inputs, 2)
4849
validate_same_dtype(self.target, [*inputs, output], ts)
49-
50-
# Handle int8 (quantized) and int32
51-
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
52-
if inputs[0].dtype not in supported_dtypes:
53-
raise TypeError(
54-
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
55-
)
50+
validate_valid_dtype(
51+
self.target,
52+
[*inputs, output],
53+
[ts.DType.INT8, ts.DType.INT32],
54+
output.tosa_spec,
55+
)
5656

5757
dim_order = (
5858
inputs[0].dim_order
@@ -125,10 +125,9 @@ def define_node(
125125
super().define_node(node, tosa_graph, inputs, output)
126126
else:
127127
# FP32 Add lowering
128-
if inputs[0].dtype != ts.DType.FP32:
129-
raise TypeError(
130-
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
131-
)
128+
validate_valid_dtype(
129+
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
130+
)
132131

133132
input1, input2 = inputs
134133

@@ -164,13 +163,13 @@ def define_node(
164163

165164
validate_num_inputs(self.target, inputs, 2)
166165
validate_same_dtype(self.target, [*inputs, output], ts)
166+
validate_valid_dtype(
167+
self.target,
168+
[*inputs, output],
169+
[ts.DType.INT8, ts.DType.INT32],
170+
output.tosa_spec,
171+
)
167172

168-
# Handle int8 (quantized) and int32
169-
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
170-
if inputs[0].dtype not in supported_dtypes:
171-
raise TypeError(
172-
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
173-
)
174173
scale_back = 1.0
175174
if inputs[0].dtype == ts.DType.INT8:
176175
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
@@ -233,10 +232,9 @@ def define_node(
233232
super().define_node(node, tosa_graph, inputs, output)
234233
else:
235234
# FP32 Add lowering
236-
if inputs[0].dtype != ts.DType.FP32:
237-
raise TypeError(
238-
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
239-
)
235+
validate_valid_dtype(
236+
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
237+
)
240238

241239
input1, input2 = inputs
242240

backends/arm/operators/op_amax.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from executorch.backends.arm.operators.operator_validation_utils import (
1313
validate_num_inputs,
1414
validate_same_dtype,
15+
validate_valid_dtype,
1516
)
1617
from executorch.backends.arm.tosa_mapping import TosaArg
1718
from torch.fx import Node
@@ -37,6 +38,12 @@ def define_node(
3738

3839
validate_num_inputs(self.target, inputs, 3)
3940
validate_same_dtype(self.target, [inputs[0], output], ts)
41+
validate_valid_dtype(
42+
self.target,
43+
[inputs[0], output],
44+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
45+
output.tosa_spec,
46+
)
4047

4148
input = inputs[0]
4249
dim = inputs[1].number
@@ -80,6 +87,12 @@ def define_node(
8087

8188
validate_num_inputs(self.target, inputs, 3)
8289
validate_same_dtype(self.target, [inputs[0], output], ts)
90+
validate_valid_dtype(
91+
self.target,
92+
[inputs[0], output],
93+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
94+
output.tosa_spec,
95+
)
8396

8497
input = inputs[0]
8598
dim = inputs[1].number

backends/arm/operators/op_amin.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from executorch.backends.arm.operators.operator_validation_utils import (
1313
validate_num_inputs,
1414
validate_same_dtype,
15+
validate_valid_dtype,
1516
)
1617
from executorch.backends.arm.tosa_mapping import TosaArg
1718
from torch.fx import Node
@@ -37,6 +38,12 @@ def define_node(
3738

3839
validate_num_inputs(self.target, inputs, 3)
3940
validate_same_dtype(self.target, [inputs[0], output], ts)
41+
validate_valid_dtype(
42+
self.target,
43+
[inputs[0], output],
44+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
45+
output.tosa_spec,
46+
)
4047

4148
input = inputs[0]
4249
dim = inputs[1].number
@@ -80,6 +87,12 @@ def define_node(
8087

8188
validate_num_inputs(self.target, inputs, 3)
8289
validate_same_dtype(self.target, [inputs[0], output], ts)
90+
validate_valid_dtype(
91+
self.target,
92+
[inputs[0], output],
93+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
94+
output.tosa_spec,
95+
)
8396

8497
input = inputs[0]
8598
dim = inputs[1].number

backends/arm/operators/op_any.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm.operators.operator_validation_utils import (
1414
validate_num_inputs,
1515
validate_same_dtype,
16+
validate_valid_dtype,
1617
)
1718

1819
from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
@@ -36,9 +37,9 @@ def define_node(
3637

3738
validate_num_inputs(self.target, inputs, 3)
3839
validate_same_dtype(self.target, [inputs[0], output], ts)
39-
40-
if not (inputs[0].dtype == ts.DType.BOOL):
41-
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
40+
validate_valid_dtype(
41+
self.target, [inputs[0], output], ts.DType.BOOL, output.tosa_spec
42+
)
4243

4344
input_shape = list(inputs[0].shape)
4445
dim = cast(int, inputs[1].number) % len(
@@ -73,9 +74,9 @@ def define_node(
7374

7475
validate_num_inputs(self.target, inputs, 3)
7576
validate_same_dtype(self.target, [inputs[0], output], ts)
76-
77-
if not (inputs[0].dtype == ts.DType.BOOL):
78-
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
77+
validate_valid_dtype(
78+
self.target, [inputs[0], output], ts.DType.BOOL, output.tosa_spec
79+
)
7980

8081
input_shape = list(inputs[0].shape)
8182
dim = cast(int, inputs[1].number) % len(

backends/arm/operators/op_avg_pool2d.py

Lines changed: 19 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
adjust_pooling_pad_if_needed,
2121
validate_num_inputs,
2222
validate_same_dtype,
23+
validate_valid_dtype,
2324
)
2425
from executorch.backends.arm.tosa_mapping import TosaArg
2526
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -106,13 +107,9 @@ def define_node(
106107

107108
validate_num_inputs(self.target, inputs, [3, 4, 6])
108109
validate_same_dtype(self.target, [inputs[0], output], ts)
109-
110-
supported_dtypes = [ts.DType.INT8]
111-
if inputs[0].dtype not in supported_dtypes:
112-
raise TypeError(
113-
f"IO data type needs to be one of {supported_dtypes}, got "
114-
f'"{inputs[0].dtype}"'
115-
)
110+
validate_valid_dtype(
111+
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
112+
)
116113

117114
accumulator_type = ts.DType.INT32
118115

@@ -146,13 +143,12 @@ def define_node(
146143

147144
validate_num_inputs(self.target, inputs, [3, 4, 6])
148145
validate_same_dtype(self.target, [inputs[0], output], ts)
149-
150-
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
151-
if inputs[0].dtype not in supported_dtypes:
152-
raise TypeError(
153-
f"IO data type needs to be one of {supported_dtypes}, got "
154-
f'"{inputs[0].dtype}"'
155-
)
146+
validate_valid_dtype(
147+
self.target,
148+
[inputs[0], output],
149+
[ts.DType.INT8, ts.DType.FP32],
150+
output.tosa_spec,
151+
)
156152

157153
if inputs[0].dtype == ts.DType.INT8:
158154
super().define_node(node, tosa_graph, inputs, output)
@@ -253,13 +249,9 @@ def define_node(
253249

254250
validate_num_inputs(self.target, inputs, [3, 4, 6])
255251
validate_same_dtype(self.target, [inputs[0], output], ts)
256-
257-
supported_dtypes = [ts.DType.INT8]
258-
if inputs[0].dtype not in supported_dtypes:
259-
raise TypeError(
260-
f"IO data type needs to be one of {supported_dtypes}, got "
261-
f'"{inputs[0].dtype}"'
262-
)
252+
validate_valid_dtype(
253+
self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
254+
)
263255

264256
accumulator_type = ts.DType.INT32
265257

@@ -296,13 +288,12 @@ def define_node(
296288

297289
validate_num_inputs(self.target, inputs, [3, 4, 6])
298290
validate_same_dtype(self.target, [inputs[0], output], ts)
299-
300-
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
301-
if inputs[0].dtype not in supported_dtypes:
302-
raise TypeError(
303-
f"IO data type needs to be one of {supported_dtypes}, got "
304-
f'"{inputs[0].dtype}"'
305-
)
291+
validate_valid_dtype(
292+
self.target,
293+
[inputs[0], output],
294+
[ts.DType.INT8, ts.DType.FP32],
295+
output.tosa_spec,
296+
)
306297

307298
if inputs[0].dtype == ts.DType.INT8:
308299
super().define_node(node, tosa_graph, inputs, output)

backends/arm/operators/op_bmm.py

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from executorch.backends.arm.operators.operator_validation_utils import (
2121
validate_num_inputs,
2222
validate_same_dtype,
23+
validate_valid_dtype,
2324
)
2425
from executorch.backends.arm.tosa_mapping import TosaArg
2526
from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80
@@ -51,15 +52,12 @@ def define_node(
5152

5253
validate_num_inputs(self.target, inputs, 2)
5354
validate_same_dtype(self.target, [*inputs, output], ts)
54-
55-
# aten.bmm maps directly to MATMUL
56-
# NOTE: For now, only INT8 & FP32 is supported
57-
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
58-
for input in inputs:
59-
if input.dtype not in supported_dtypes:
60-
raise TypeError(
61-
f'IO data type needs to be {supported_dtypes}, got "{input.dtype}"'
62-
)
55+
validate_valid_dtype(
56+
self.target,
57+
[*inputs, output],
58+
[ts.DType.INT8, ts.DType.FP32],
59+
output.tosa_spec,
60+
)
6361

6462
# aten.bmm maps directly to MATMUL
6563

@@ -130,18 +128,14 @@ def define_node(
130128

131129
validate_num_inputs(self.target, inputs, 2)
132130
validate_same_dtype(self.target, [*inputs, output], ts)
131+
validate_valid_dtype(
132+
self.target,
133+
[*inputs, output],
134+
[ts.DType.INT8, ts.DType.FP32],
135+
output.tosa_spec,
136+
)
133137

134138
# aten.bmm maps directly to MATMUL
135-
# NOTE: For now, only INT8 & FP32 is supported
136-
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
137-
for input in inputs:
138-
if input.dtype not in supported_dtypes:
139-
raise TypeError(
140-
f'IO data type needs to be {supported_dtypes}, got "{input.dtype}"'
141-
)
142-
143-
# aten.bmm maps directly to MATMUL
144-
# NOTE: For now, only INT8 & FP32 is supported
145139

146140
# For INT8, we need to get the zero points and add an intermediate tensor
147141
# for a later rescale.

0 commit comments

Comments (0)