Commit c0e9a6c

Merge branch 'pytorch:main' into toupstream/mm_fix
2 parents a2fecb1 + c9f5f19 commit c0e9a6c

4 files changed: +131 -5 lines changed

4 files changed

+131
-5
lines changed

backends/arm/arm_vela.py

Lines changed: 2 additions & 2 deletions
@@ -96,13 +96,13 @@ def vela_compile(tosa_graph, args: List[str], shape_order=None):
            block_name = block_name + b"\x00" * (16 - len(block_name))

            # We need the acual unpadded block lengths for hw setup
-           block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0)  # type: ignore[assignment]
+           block_length_bytes = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0)

            # Pad block data to multiple of 16 bytes
            block_data = bin_blocks[key]
            block_data = block_data + b"\x00" * (15 - (len(block_data) - 1) % 16)

-           block = block_name + block_length + block_data  # type: ignore[operator]
+           block = block_name + block_length_bytes + block_data
            blocks = blocks + block

    return blocks
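
The rename gives the packed header bytes their own name, so the `# type: ignore[assignment]` and `# type: ignore[operator]` escapes are no longer needed. For reference, each block in the stream ends up as a 16-byte NUL-padded name, a 16-byte little-endian length header, and the payload padded to a 16-byte boundary. A minimal standalone sketch of that framing, with an illustrative function name and inputs that are not part of the executorch API:

import struct


def pack_block(name: str, data: bytes) -> bytes:
    # Standalone sketch; pack_block is not a function in arm_vela.py.
    # 16-byte name field, NUL-padded (longer names would be truncated to 15 bytes).
    block_name = name.encode("utf8")[:15]
    block_name = block_name + b"\x00" * (16 - len(block_name))

    # 16-byte header: unpadded payload length as little-endian int32, plus three zeros.
    block_length_bytes = struct.pack("<iiii", len(data), 0, 0, 0)

    # Pad the payload itself up to the next multiple of 16 bytes.
    padded = data + b"\x00" * (15 - (len(data) - 1) % 16)
    return block_name + block_length_bytes + padded


blob = pack_block("scratch", b"\x01\x02\x03")
assert len(blob) == 16 + 16 + 16 and len(blob) % 16 == 0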

backends/arm/operators/node_visitor.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def define_node(


# container for all node visitors
-_node_visitor_dicts = {  # type: ignore[var-annotated]
+_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
    TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
    TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
}
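
The explicit `Dict[TosaSpecification, Dict]` annotation states what mypy could not infer from an empty dict literal: one visitor table per TOSA specification. A rough, self-contained sketch of that registry pattern, where every name is a stand-in rather than the executorch API:

from typing import Dict, Type


class Spec:
    """Stand-in for TosaSpecification; identified only by a version string."""

    def __init__(self, version: str) -> None:
        self.version = version

    def __hash__(self) -> int:
        return hash(self.version)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Spec) and other.version == self.version


class Visitor:
    """Stand-in for a node visitor; real visitors declare a target op."""

    target = "aten.add.Tensor"


# Annotating the container up front removes the need for
# `# type: ignore[var-annotated]` on the empty dict literals.
_registry: Dict[Spec, Dict[str, Type[Visitor]]] = {
    Spec("TOSA-0.80+BI"): {},
    Spec("TOSA-0.80+MI"): {},
}


def register(spec: Spec, visitor: Type[Visitor]) -> None:
    # One visitor class per op target, per TOSA specification.
    _registry[spec][visitor.target] = visitor


register(Spec("TOSA-0.80+MI"), Visitor)
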
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

import torch
from executorch.backends.arm.test import common, conftest

from executorch.backends.arm.test.tester.arm_tester import ArmTester

from torchaudio.models import Conformer


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TestConformer(unittest.TestCase):
    """Tests Torchaudio Conformer"""

    # Adjust nbr below as we increase op support. Note: most of the delegates
    # calls are directly consecutive to each other in the .pte. The reason
    # for that is some assert ops are removed by passes in the
    # .to_executorch step, i.e. after Arm partitioner.
    ops_after_partitioner = {
        "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
        "executorch_exir_dialects_edge__ops_aten_full_like_default": 4,
        "executorch_exir_dialects_edge__ops_aten_max_default": 1,
        "executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
        "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
        "executorch_exir_dialects_edge__ops_aten_where_self": 4,
        "executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
        "executorch_exir_dialects_edge__ops_aten_any_dim": 2,
        "torch.ops.aten._assert_scalar.default": 12,
        "torch.ops.aten._local_scalar_dense.default": 1,
        "torch.ops.aten.scalar_tensor.default": 2,
        "torch.ops.higher_order.executorch_call_delegate": 5,
    }

    dim = 16
    lengths = torch.randint(1, 100, (10,), dtype=torch.int32)
    input_data = torch.rand(10, int(lengths.max()), dim)
    conformer = Conformer(
        input_dim=dim,
        num_heads=4,
        ffn_dim=64,
        num_layers=2,
        depthwise_conv_kernel_size=31,
    )
    conformer = conformer.eval()

    def test_conformer_tosa_MI(self):
        (
            ArmTester(
                self.conformer,
                example_inputs=(self.input_data, self.lengths),
                compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+MI"),
            )
            .export()
            .to_edge_transform_and_lower()
            .dump_operator_distribution()
            .check_count(self.ops_after_partitioner)
            .to_executorch()
            # TODO(MLETORCH-632): Fix numerical errors
            .run_method_and_compare_outputs(
                inputs=(self.input_data, self.lengths), rtol=1, atol=5
            )
        )

    @unittest.expectedFailure  # TODO(MLETORCH-635)
    def test_conformer_tosa_BI(self):
        (
            ArmTester(
                self.conformer,
                example_inputs=(self.input_data, self.lengths),
                compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+BI"),
            )
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .run_method_and_compare_outputs(
                qtol=1, rtol=1, atol=5, inputs=(self.input_data, self.lengths)
            )
        )

    @unittest.expectedFailure  # TODO(MLETORCH-635)
    def test_conformer_u55_BI(self):
        tester = (
            ArmTester(
                self.conformer,
                example_inputs=(self.input_data, self.lengths),
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
        )
        if conftest.is_option_enabled("corstone_fvp"):
            tester.run_method_and_compare_outputs(
                atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
            )

    @unittest.expectedFailure  # TODO(MLETORCH-635)
    def test_conformer_u85_BI(self):
        tester = (
            ArmTester(
                self.conformer,
                example_inputs=(self.input_data, self.lengths),
                compile_spec=common.get_u85_compile_spec(),
            )
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
        )
        if conftest.is_option_enabled("corstone_fvp"):
            tester.run_method_and_compare_outputs(
                atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
            )
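
For reference outside the ArmTester pipeline, the tests above feed torchaudio's Conformer its documented signature: a (batch, time, input_dim) tensor plus per-sequence lengths, returning the transformed frames and the output lengths. A short eager-mode sketch using the same constructor arguments as the test class:

import torch
from torchaudio.models import Conformer

dim = 16
lengths = torch.randint(1, 100, (10,), dtype=torch.int32)
input_data = torch.rand(10, int(lengths.max()), dim)

model = Conformer(
    input_dim=dim,
    num_heads=4,
    ffn_dim=64,
    num_layers=2,
    depthwise_conv_kernel_size=31,
).eval()

with torch.no_grad():
    output, output_lengths = model(input_data, lengths)

# output keeps the (batch, time, input_dim) layout; the lengths tensor is used
# internally to mask padded frames in the attention and convolution blocks.
print(output.shape, output_lengths.shape)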

backends/arm/util/arm_model_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ def __init__(
        if tosa_output_path:
            self.tosa_output_path = tosa_output_path
        else:
-           self.tosa_output_path = None  # type: ignore[assignment]
+           self.tosa_output_path = ""

    def get_model_error(self) -> defaultdict:
        """
@@ -104,7 +104,7 @@ def get_compression_ratio(self) -> float:

        return compression_ratio

-   def evaluate(self) -> dict[Any]:  # type: ignore[type-arg]
+   def evaluate(self) -> dict[str, Any]:
        model_error_dict = self.get_model_error()

        output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}
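
Both fixes swap a `# type: ignore` for something the checker accepts: an empty string keeps `tosa_output_path` a plain `str`, and `dict[str, Any]` supplies the key type that `dict[Any]` was missing. A small standalone illustration of the same pattern, using a hypothetical stub class rather than the real evaluator:

from typing import Any


class EvaluatorStub:
    """Mirrors the pattern above without the real evaluator's dependencies."""

    def __init__(self, model_name: str, tosa_output_path: str = "") -> None:
        self.model_name = model_name
        # "" as the "no path" sentinel keeps the attribute a plain str,
        # so no `# type: ignore[assignment]` is needed for the default case.
        self.tosa_output_path = tosa_output_path

    def evaluate(self) -> dict[str, Any]:
        # `dict` takes both a key and a value parameter; `dict[Any]` alone is
        # rejected by type checkers such as mypy.
        return {"name": self.model_name, "metrics": {"top1_error": 0.0}}


print(EvaluatorStub("mobilenet_v2").evaluate())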
