Skip to content

Commit 0ea3aed

Browse files
authored
Merge branch 'main' into add_hardsigmoid_op
2 parents 06cfdc9 + c9f5f19 commit 0ea3aed

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.test import common, conftest
11+
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
14+
from torchaudio.models import Conformer
15+
16+
17+
logger = logging.getLogger(__name__)
18+
logger.setLevel(logging.INFO)
19+
20+
21+
class TestConformer(unittest.TestCase):
22+
"""Tests Torchaudio Conformer"""
23+
24+
# Adjust nbr below as we increase op support. Note: most of the delegates
25+
# calls are directly consecutive to each other in the .pte. The reason
26+
# for that is some assert ops are removed by passes in the
27+
# .to_executorch step, i.e. after Arm partitioner.
28+
ops_after_partitioner = {
29+
"executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
30+
"executorch_exir_dialects_edge__ops_aten_full_like_default": 4,
31+
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
32+
"executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
33+
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
34+
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
35+
"executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
36+
"executorch_exir_dialects_edge__ops_aten_any_dim": 2,
37+
"torch.ops.aten._assert_scalar.default": 12,
38+
"torch.ops.aten._local_scalar_dense.default": 1,
39+
"torch.ops.aten.scalar_tensor.default": 2,
40+
"torch.ops.higher_order.executorch_call_delegate": 5,
41+
}
42+
43+
dim = 16
44+
lengths = torch.randint(1, 100, (10,), dtype=torch.int32)
45+
input_data = torch.rand(10, int(lengths.max()), dim)
46+
conformer = Conformer(
47+
input_dim=dim,
48+
num_heads=4,
49+
ffn_dim=64,
50+
num_layers=2,
51+
depthwise_conv_kernel_size=31,
52+
)
53+
conformer = conformer.eval()
54+
55+
def test_conformer_tosa_MI(self):
56+
(
57+
ArmTester(
58+
self.conformer,
59+
example_inputs=(self.input_data, self.lengths),
60+
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+MI"),
61+
)
62+
.export()
63+
.to_edge_transform_and_lower()
64+
.dump_operator_distribution()
65+
.check_count(self.ops_after_partitioner)
66+
.to_executorch()
67+
# TODO(MLETORCH-632): Fix numerical errors
68+
.run_method_and_compare_outputs(
69+
inputs=(self.input_data, self.lengths), rtol=1, atol=5
70+
)
71+
)
72+
73+
@unittest.expectedFailure # TODO(MLETORCH-635)
74+
def test_conformer_tosa_BI(self):
75+
(
76+
ArmTester(
77+
self.conformer,
78+
example_inputs=(self.input_data, self.lengths),
79+
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+BI"),
80+
)
81+
.quantize()
82+
.export()
83+
.to_edge_transform_and_lower()
84+
.to_executorch()
85+
.run_method_and_compare_outputs(
86+
qtol=1, rtol=1, atol=5, inputs=(self.input_data, self.lengths)
87+
)
88+
)
89+
90+
@unittest.expectedFailure # TODO(MLETORCH-635)
91+
def test_conformer_u55_BI(self):
92+
tester = (
93+
ArmTester(
94+
self.conformer,
95+
example_inputs=(self.input_data, self.lengths),
96+
compile_spec=common.get_u55_compile_spec(),
97+
)
98+
.quantize()
99+
.export()
100+
.to_edge_transform_and_lower()
101+
.to_executorch()
102+
.serialize()
103+
)
104+
if conftest.is_option_enabled("corstone_fvp"):
105+
tester.run_method_and_compare_outputs(
106+
atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
107+
)
108+
109+
@unittest.expectedFailure # TODO(MLETORCH-635)
110+
def test_conformer_u85_BI(self):
111+
tester = (
112+
ArmTester(
113+
self.conformer,
114+
example_inputs=(self.input_data, self.lengths),
115+
compile_spec=common.get_u85_compile_spec(),
116+
)
117+
.quantize()
118+
.export()
119+
.to_edge_transform_and_lower()
120+
.to_executorch()
121+
.serialize()
122+
)
123+
if conftest.is_option_enabled("corstone_fvp"):
124+
tester.run_method_and_compare_outputs(
125+
atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
126+
)

0 commit comments

Comments
 (0)