33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- import unittest
6+ from typing import Tuple
7+
8+ import pytest
79
810import torch
9- from executorch .backends .arm .test import common , conftest
1011
11- from executorch .backends .arm .test .tester .arm_tester import ArmTester
12+ from executorch .backends .arm .test import common
13+ from executorch .backends .arm .test .tester .test_pipeline import (
14+ EthosU55PipelineBI ,
15+ EthosU85PipelineBI ,
16+ TosaPipelineBI ,
17+ TosaPipelineMI ,
18+ )
1219
1320from torchaudio .models import Conformer
1421
22+ input_t = Tuple [torch .Tensor , torch .IntTensor ] # Input x, y
23+
1524
1625def get_test_inputs (dim , lengths , num_examples ):
1726 return (torch .rand (num_examples , int (lengths .max ()), dim ), lengths )
1827
1928
20- class TestConformer ( unittest . TestCase ) :
29+ class TestConformer :
2130 """Tests Torchaudio Conformer"""
2231
2332 # Adjust nbr below as we increase op support. Note: most of the delegates
2433 # calls are directly consecutive to each other in the .pte. The reason
2534 # for that is some assert ops are removed by passes in the
2635 # .to_executorch step, i.e. after Arm partitioner.
27- ops_after_partitioner = {
28- "executorch_exir_dialects_edge__ops_aten_max_default" : 1 ,
29- "torch.ops.aten._assert_scalar.default" : 7 ,
30- "torch.ops.aten._local_scalar_dense.default" : 1 ,
31- }
36+ aten_ops = ["torch.ops.aten._assert_scalar.default" ]
37+ exir_ops = ["executorch_exir_dialects_edge__ops_aten_max_default" ]
3238
3339 dim = 16
3440 num_examples = 10
@@ -43,92 +49,87 @@ class TestConformer(unittest.TestCase):
4349 )
4450 conformer = conformer .eval ()
4551
46- def test_conformer_tosa_MI (self ):
47- (
48- ArmTester (
49- self .conformer ,
50- example_inputs = self .model_example_inputs ,
51- compile_spec = common .get_tosa_compile_spec (tosa_spec = "TOSA-0.80+MI" ),
52- )
53- .export ()
54- .to_edge_transform_and_lower ()
55- .dump_operator_distribution ()
56- .check_count (self .ops_after_partitioner )
57- .to_executorch ()
58- # TODO(MLETORCH-632): Fix numerical errors
59- .run_method_and_compare_outputs (
60- rtol = 1.0 ,
61- atol = 5.0 ,
62- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
63- )
64- )
65-
66- @unittest .expectedFailure # TODO(MLETORCH-635)
67- def test_conformer_tosa_BI (self ):
68- (
69- ArmTester (
70- self .conformer ,
71- example_inputs = self .model_example_inputs ,
72- compile_spec = common .get_tosa_compile_spec (tosa_spec = "TOSA-0.80+BI" ),
73- )
74- .quantize ()
75- .export ()
76- .to_edge_transform_and_lower ()
77- .to_executorch ()
78- .run_method_and_compare_outputs (
79- qtol = 1.0 ,
80- rtol = 1.0 ,
81- atol = 5.0 ,
82- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
83- )
84- )
85-
86- def test_conformer_u55_BI (self ):
87- tester = (
88- ArmTester (
89- self .conformer ,
90- example_inputs = self .model_example_inputs ,
91- compile_spec = common .get_u55_compile_spec (),
92- )
93- .quantize ()
94- .export ()
95- .to_edge_transform_and_lower ()
96- .to_executorch ()
97- .serialize ()
98- )
99-
100- if conftest .is_option_enabled ("corstone_fvp" ):
101- try :
102- tester .run_method_and_compare_outputs (
103- qtol = 1.0 ,
104- rtol = 1.0 ,
105- atol = 5.0 ,
106- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
107- )
108- self .fail (
109- "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
110- )
111- except Exception :
112- pass
113-
114- @unittest .expectedFailure # TODO(MLETORCH-635)
115- def test_conformer_u85_BI (self ):
116- tester = (
117- ArmTester (
118- self .conformer ,
119- example_inputs = self .model_example_inputs ,
120- compile_spec = common .get_u85_compile_spec (),
121- )
122- .quantize ()
123- .export ()
124- .to_edge_transform_and_lower ()
125- .to_executorch ()
126- .serialize ()
127- )
128- if conftest .is_option_enabled ("corstone_fvp" ):
129- tester .run_method_and_compare_outputs (
130- qtol = 1.0 ,
131- rtol = 1.0 ,
132- atol = 5.0 ,
133- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
134- )
52+
53+ def test_conformer_tosa_MI ():
54+ pipeline = TosaPipelineMI [input_t ](
55+ TestConformer .conformer ,
56+ TestConformer .model_example_inputs ,
57+ aten_op = TestConformer .aten_ops ,
58+ exir_op = [],
59+ use_to_edge_transform_and_lower = True ,
60+ )
61+ pipeline .change_args (
62+ "run_method_and_compare_outputs" ,
63+ get_test_inputs (
64+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
65+ ),
66+ rtol = 1.0 ,
67+ atol = 5.0 ,
68+ )
69+ pipeline .run ()
70+
71+
72+ @pytest .mark .xfail (reason = "All IO needs to have the same data type (MLETORCH-635)" )
73+ def test_conformer_tosa_BI ():
74+ pipeline = TosaPipelineBI [input_t ](
75+ TestConformer .conformer ,
76+ TestConformer .model_example_inputs ,
77+ aten_op = TestConformer .aten_ops ,
78+ exir_op = TestConformer .exir_ops ,
79+ use_to_edge_transform_and_lower = True ,
80+ )
81+ pipeline .change_args (
82+ "run_method_and_compare_outputs" ,
83+ get_test_inputs (
84+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
85+ ),
86+ rtol = 1.0 ,
87+ atol = 5.0 ,
88+ )
89+ pipeline .run ()
90+
91+
92+ @common .XfailIfNoCorstone300
93+ @pytest .mark .xfail (
94+ reason = "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
95+ )
96+ def test_conformer_u55_BI ():
97+ pipeline = EthosU55PipelineBI [input_t ](
98+ TestConformer .conformer ,
99+ TestConformer .model_example_inputs ,
100+ aten_ops = TestConformer .aten_ops ,
101+ exir_ops = TestConformer .exir_ops ,
102+ use_to_edge_transform_and_lower = True ,
103+ run_on_fvp = True ,
104+ )
105+ pipeline .change_args (
106+ "run_method_and_compare_outputs" ,
107+ get_test_inputs (
108+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
109+ ),
110+ rtol = 1.0 ,
111+ atol = 5.0 ,
112+ )
113+ pipeline .run ()
114+
115+
116+ @common .XfailIfNoCorstone320
117+ @pytest .mark .xfail (reason = "All IO needs to have the same data type (MLETORCH-635)" )
118+ def test_conformer_u85_BI ():
119+ pipeline = EthosU85PipelineBI [input_t ](
120+ TestConformer .conformer ,
121+ TestConformer .model_example_inputs ,
122+ aten_ops = TestConformer .aten_ops ,
123+ exir_ops = TestConformer .exir_ops ,
124+ use_to_edge_transform_and_lower = True ,
125+ run_on_fvp = True ,
126+ )
127+ pipeline .change_args (
128+ "run_method_and_compare_outputs" ,
129+ get_test_inputs (
130+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
131+ ),
132+ rtol = 1.0 ,
133+ atol = 5.0 ,
134+ )
135+ pipeline .run ()
0 commit comments