Skip to content

Commit ed20baf

Browse files
authored
Merge branch 'main' into aoti_support_multi_method
2 parents 7074422 + 7ed9266 commit ed20baf

File tree

162 files changed

+3700
-1889
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

162 files changed

+3700
-1889
lines changed

backends/arm/operator_support/pool_2d_support.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide TOSA support checks for 2D pooling.
6+
7+
Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile
8+
constraints including kernel size, stride, padding, and dimensionality.
9+
10+
"""
511

612
from typing import cast
713

@@ -20,16 +26,48 @@
2026

2127

2228
def kernel_check(kernel: tuple[int, int]) -> bool:
29+
"""Check if kernel size is within U55 constraints.
30+
31+
Checks that ``kernel[0] * kernel[1]`` is in ``[1, 65536]`` and
32+
``kernel[1]`` is in ``[1, 256]`` as required by the U55 profile.
33+
34+
Args:
35+
kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``.
36+
37+
Returns:
38+
bool: True if the kernel passes validation.
39+
40+
"""
2341
if not (1 <= kernel[0] * kernel[1] <= 65536):
2442
return False
2543
return 1 <= kernel[1] <= 256
2644

2745

2846
def stride_check(strides: tuple[int, int]) -> bool:
47+
"""Check if strides are within U55 constraints.
48+
49+
Args:
50+
strides (tuple[int, int]): Vertical and horizontal strides.
51+
52+
Returns:
53+
bool: True if each stride is in ``[1, 3]``.
54+
55+
"""
2956
return all(1 <= stride <= 3 for stride in strides)
3057

3158

3259
def dim_check(shape=torch.Size) -> bool:
60+
"""Check if non-batch dims are within U55 constraints.
61+
62+
Verifies that all dimensions except batch are in ``[1, 65536]``.
63+
64+
Args:
65+
shape (torch.Size): Input tensor shape.
66+
67+
Returns:
68+
bool: True if all checked dimensions pass.
69+
70+
"""
3371
check = True
3472
for dim in shape[1:]:
3573
check &= 1 <= dim <= 65536
@@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool:
3876

3977
@register_tosa_support_check
4078
class AvgPool2dSupported(SupportedTOSAOperatorCheck):
79+
"""Provide TOSA support checks for ``aten.avg_pool2d``.
80+
81+
Applies additional constraints when targeting the U55 subset, including
82+
limits on kernel size, stride, padding behavior, and tensor ranks.
83+
84+
"""
85+
4186
targets = [
4287
exir_ops.edge.aten.avg_pool2d.default,
4388
]
@@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4893
]
4994

5095
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
96+
"""Return True if ``avg_pool2d`` satisfies U55 constraints.
97+
98+
Computes the effective TOSA padding (depending on ``count_include_pad``
99+
and ``divisor_override``) and validates kernel, stride, and shape limits.
100+
101+
"""
51102
if not tosa_spec.is_U55_subset:
52103
return True
53104

@@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
115166

116167
@register_tosa_support_check
117168
class MaxPool2dSupported(SupportedTOSAOperatorCheck):
169+
"""Provide TOSA support checks for ``aten.max_pool2d_with_indices``.
170+
171+
Applies additional constraints when targeting the U55 subset, including
172+
limits on kernel size, stride, and tensor ranks.
173+
174+
"""
175+
118176
targets = [
119177
exir_ops.edge.aten.max_pool2d_with_indices.default,
120178
]
@@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
125183
]
126184

127185
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
186+
"""Return True if ``max_pool2d_with_indices`` satisfies U55
187+
constraints.
188+
"""
128189
if not tosa_spec.is_U55_subset:
129190
return True
130191

backends/arm/operators/op_sub.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def define_node(
5050
validate_valid_dtype(
5151
self.target,
5252
[*inputs, output],
53-
[ts.DType.INT8, ts.DType.INT32],
53+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5454
output.tosa_spec,
5555
)
5656

@@ -59,12 +59,18 @@ def define_node(
5959
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6060
tosa_graph, inputs, node, self.tosa_spec
6161
)
62+
elif inputs[0].dtype == ts.DType.INT16:
63+
rescaled_inputs, scale_back = (
64+
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
65+
tosa_graph, inputs, node, self.tosa_spec
66+
)
67+
)
6268
else:
6369
# input[0].dtype == ts.DType.INT32
6470
# Non-quantized input, natively supported by TOSA.SUB
6571
rescaled_inputs = inputs
6672

67-
if output.dtype == ts.DType.INT8:
73+
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
6874
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6975
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
7076
else:
@@ -95,6 +101,15 @@ def define_node(
95101
compute_rescale=False,
96102
tosa_spec=self.tosa_spec,
97103
) # type: ignore[possibly-undefined]
104+
elif output.dtype == ts.DType.INT16:
105+
tqutils.insert_rescale_op_to_int16(
106+
tosa_graph,
107+
sub_output,
108+
scale_back,
109+
node,
110+
compute_rescale=False,
111+
tosa_spec=self.tosa_spec,
112+
) # type: ignore[possibly-undefined]
98113

99114

100115
@register_node_visitor

backends/arm/test/misc/test_conv_relu_residual_add.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def test_tosa_u55_INT(per_channel_quantization):
8585
model_inputs,
8686
[],
8787
[],
88-
run_on_fvp=True,
8988
use_to_edge_transform_and_lower=True,
9089
per_channel_quantization=per_channel_quantization,
9190
qtol=0,
@@ -102,7 +101,6 @@ def test_tosa_u85_INT(per_channel_quantization):
102101
model_inputs,
103102
[],
104103
[],
105-
run_on_fvp=True,
106104
use_to_edge_transform_and_lower=True,
107105
per_channel_quantization=per_channel_quantization,
108106
qtol=0,

backends/arm/test/misc/test_debug_feats.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,10 @@ def forward(self, x):
262262

263263

264264
@common.parametrize("test_data", Add.inputs)
265+
@common.XfailIfNoCorstone300
265266
def test_fail_dump_tosa_ops(caplog, test_data: input_t1):
266267
pipeline = EthosU55PipelineINT[input_t1](
267-
Add(), test_data, [], [], use_to_edge_transform_and_lower=True, run_on_fvp=False
268+
Add(), test_data, [], [], use_to_edge_transform_and_lower=True
268269
)
269270
pipeline.dump_operator_distribution("to_edge_transform_and_lower")
270271
pipeline.run()

backends/arm/test/models/test_conformer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ def test_conformer_u55_INT():
9292
aten_ops=TestConformer.aten_ops,
9393
exir_ops=[],
9494
use_to_edge_transform_and_lower=True,
95-
run_on_fvp=True,
9695
)
9796
pipeline.change_args(
9897
"run_method_and_compare_outputs",
@@ -114,7 +113,6 @@ def test_conformer_u85_INT():
114113
aten_ops=TestConformer.aten_ops,
115114
exir_ops=[],
116115
use_to_edge_transform_and_lower=True,
117-
run_on_fvp=True,
118116
)
119117
pipeline.change_args(
120118
"run_method_and_compare_outputs",

backends/arm/test/models/test_dl3_arm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def test_dl3_u55_INT():
6666
TestDl3.model_example_inputs,
6767
aten_ops=[],
6868
exir_ops=[],
69-
run_on_fvp=True,
7069
)
7170
pipeline.change_args(
7271
"run_method_and_compare_outputs", rtol=1.0, atol=1.0
@@ -82,7 +81,6 @@ def test_dl3_u85_INT():
8281
TestDl3.model_example_inputs,
8382
aten_ops=[],
8483
exir_ops=[],
85-
run_on_fvp=True,
8684
)
8785
pipeline.change_args(
8886
"run_method_and_compare_outputs", rtol=1.0, atol=1.0

backends/arm/test/models/test_inception_v3_arm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def test_ic3_u55_BI():
6666
model_inputs,
6767
aten_ops=[],
6868
exir_ops=[],
69-
run_on_fvp=True,
7069
use_to_edge_transform_and_lower=True,
7170
atol=0.6,
7271
qtol=1,
@@ -83,7 +82,6 @@ def test_ic3_u85_BI():
8382
model_inputs,
8483
aten_ops=[],
8584
exir_ops=[],
86-
run_on_fvp=True,
8785
use_to_edge_transform_and_lower=True,
8886
atol=0.6,
8987
qtol=1,

backends/arm/test/models/test_lstm_arm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def test_lstm_u55_INT():
7777
aten_ops=[],
7878
exir_ops=[],
7979
use_to_edge_transform_and_lower=True,
80-
run_on_fvp=True,
8180
)
8281
pipeline.change_args(
8382
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0
@@ -93,7 +92,6 @@ def test_lstm_u85_INT():
9392
aten_ops=[],
9493
exir_ops=[],
9594
use_to_edge_transform_and_lower=True,
96-
run_on_fvp=True,
9795
)
9896
pipeline.change_args(
9997
"run_method_and_compare_outputs", get_test_inputs(), atol=3e-1, qtol=1.0

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def test_mv2_u55_INT(per_channel_quantization):
8787
model_inputs,
8888
aten_ops=[],
8989
exir_ops=[],
90-
run_on_fvp=True,
9190
use_to_edge_transform_and_lower=True,
9291
per_channel_quantization=per_channel_quantization,
9392
atol=0.25,
@@ -105,7 +104,6 @@ def test_mv2_u85_INT(per_channel_quantization):
105104
model_inputs,
106105
aten_ops=[],
107106
exir_ops=[],
108-
run_on_fvp=True,
109107
use_to_edge_transform_and_lower=True,
110108
per_channel_quantization=per_channel_quantization,
111109
atol=0.25,

backends/arm/test/models/test_mobilenet_v3_arm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def test_mv3_u55_INT():
6161
model_inputs,
6262
aten_ops=[],
6363
exir_ops=[],
64-
run_on_fvp=True,
6564
use_to_edge_transform_and_lower=True,
6665
atol=0.5,
6766
qtol=1,
@@ -77,7 +76,6 @@ def test_mv3_u85_INT():
7776
model_inputs,
7877
aten_ops=[],
7978
exir_ops=[],
80-
run_on_fvp=True,
8179
use_to_edge_transform_and_lower=True,
8280
atol=0.5,
8381
qtol=1,

0 commit comments

Comments
 (0)