Skip to content

Commit d536d18

Browse files
Arm backend: Improve dtype validation (pytorch#15871)
Improve dtype validation in node-visitors. Signed-off-by: Oscar Andersson <[email protected]>
1 parent de5962d commit d536d18

26 files changed

+223
-72
lines changed

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ def __init__(self, reporter: WhyNoPartitionReporter):
7878

7979
targeted_ops_i8_i16_i32 = [
8080
exir_ops.edge.aten.cat.default,
81+
exir_ops.edge.aten.expand_copy.default,
8182
exir_ops.edge.aten.repeat.default,
8283
exir_ops.edge.aten.constant_pad_nd.default,
8384
exir_ops.edge.aten.view.default,
8485
exir_ops.edge.aten.permute.default,
86+
exir_ops.edge.aten.permute_copy.default,
8587
]
8688

8789
target_ops_i8 = tuple(TableOps.included_ops())

backends/arm/operators/op_avg_pool2d.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,13 @@ def define_node(
115115
) -> None:
116116
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
117117
validate_same_dtype(self.target, [inputs[0], output], ts)
118+
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
119+
if self.tosa_spec.support_extension("int16"):
120+
supported_dtypes.append(ts.DType.INT16)
118121
validate_valid_dtype(
119122
self.target,
120123
[inputs[0], output],
121-
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
124+
supported_dtypes,
122125
output.tosa_spec,
123126
)
124127

backends/arm/operators/op_cat.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
)
1515
from executorch.backends.arm.operators.operator_validation_utils import (
1616
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
1719
)
1820
from executorch.backends.arm.tosa.mapping import TosaArg
1921
from torch.fx import Node
@@ -35,9 +37,19 @@ def define_node(
3537
inputs: List[TosaArg],
3638
output: TosaArg,
3739
) -> None:
40+
supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32]
41+
if self.tosa_spec.support_extension("int16"):
42+
supported_dtypes.append(ts.DType.INT16)
3843
validate_num_inputs(self.target, inputs, [1, 2])
44+
input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special]
45+
validate_same_dtype(self.target, [*input_tosa_args, output], ts)
46+
validate_valid_dtype(
47+
self.target,
48+
[*input_tosa_args, output],
49+
supported_dtypes,
50+
output.tosa_spec,
51+
)
3952

40-
tensors = inputs[0].special
4153
dim = 0 if len(inputs) < 2 else inputs[1].number
4254
rank = len(output.shape)
4355
dim = (dim + rank) % rank
@@ -50,7 +62,7 @@ def define_node(
5062
node,
5163
tosa_graph,
5264
ts.Op.CONCAT,
53-
[tensor.name for tensor in tensors],
65+
[tensor.name for tensor in input_tosa_args],
5466
[output.name],
5567
attr,
5668
)

backends/arm/operators/op_clamp.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,18 @@ def define_node(
8787
) -> None:
8888
validate_num_inputs(self.target, inputs, [2, 3])
8989
validate_same_dtype(self.target, [inputs[0], output], ts)
90+
supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32]
91+
if self.tosa_spec.support_extension("int16"):
92+
supported_dtypes.append(ts.DType.INT16)
9093
validate_valid_dtype(
9194
self.target,
9295
[inputs[0], output],
93-
[
94-
ts.DType.INT8,
95-
ts.DType.INT16,
96-
ts.DType.FP16,
97-
ts.DType.FP32,
98-
],
96+
supported_dtypes,
9997
output.tosa_spec,
10098
)
10199

102100
node_input_dtype = node.meta["val"].dtype
103-
# NOTE: Quantization of the min/max arguments is handled by QuantizeClampArgumentsPass
101+
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
104102
min_val, max_val = self._get_min_max_arguments(node, node_input_dtype)
105103

106104
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_index_select.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
NodeVisitor,
1313
register_node_visitor,
1414
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
19+
)
1520
from executorch.backends.arm.tosa.mapping import TosaArg
1621

1722
from executorch.backends.arm.tosa.utils import build_reshape_tosa_1_0
@@ -45,10 +50,16 @@ def define_node(
4550
inputs: List[TosaArg],
4651
output: TosaArg,
4752
) -> None:
48-
if len(inputs) != 3:
49-
raise ValueError(f"Number of inputs are not 3: {len(inputs)}")
53+
validate_num_inputs(self.target, inputs, 3)
54+
validate_same_dtype(self.target, [inputs[0], output], ts)
55+
validate_valid_dtype(
56+
self.target,
57+
[inputs[0], output],
58+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
59+
output.tosa_spec,
60+
)
5061

51-
weights, index, indices = inputs
62+
weights, _, indices = inputs
5263

5364
if len(weights.shape) == 2:
5465
weights_new_shape = [1, weights.shape[0], weights.shape[1]]

backends/arm/operators/op_le.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_lt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

0 commit comments

Comments
 (0)