33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6- import unittest
6+ from typing import Tuple
7+
8+ import pytest
79
810import torch
9- from executorch .backends .arm .test import common , conftest
1011
11- from executorch .backends .arm .test .tester .arm_tester import ArmTester
12+ from executorch .backends .arm .test import common
13+ from executorch .backends .arm .test .tester .test_pipeline import (
14+ EthosU55PipelineBI ,
15+ EthosU85PipelineBI ,
16+ TosaPipelineBI ,
17+ TosaPipelineMI ,
18+ )
1219
1320from torchaudio .models import Conformer
1421
22+ input_t = Tuple [torch .Tensor , torch .IntTensor ] # Input x, y
23+
1524
1625def get_test_inputs (dim , lengths , num_examples ):
1726 return (torch .rand (num_examples , int (lengths .max ()), dim ), lengths )
1827
1928
20- class TestConformer ( unittest . TestCase ) :
29+ class TestConformer :
2130 """Tests Torchaudio Conformer"""
2231
2332 # Adjust nbr below as we increase op support. Note: most of the delegates
2433 # calls are directly consecutive to each other in the .pte. The reason
2534 # for that is some assert ops are removed by passes in the
2635 # .to_executorch step, i.e. after Arm partitioner.
27- ops_after_partitioner = {
28- "executorch_exir_dialects_edge__ops_aten_max_default" : 1 ,
29- "torch.ops.aten._assert_scalar.default" : 7 ,
30- "torch.ops.aten._local_scalar_dense.default" : 1 ,
31- }
36+ aten_ops = ["torch.ops.aten._assert_scalar.default" ]
3237
3338 dim = 16
3439 num_examples = 10
@@ -43,96 +48,87 @@ class TestConformer(unittest.TestCase):
4348 )
4449 conformer = conformer .eval ()
4550
46- def test_conformer_tosa_MI (self ):
47- (
48- ArmTester (
49- self .conformer ,
50- example_inputs = self .model_example_inputs ,
51- compile_spec = common .get_tosa_compile_spec (tosa_spec = "TOSA-0.80+MI" ),
52- )
53- .export ()
54- .to_edge_transform_and_lower ()
55- .dump_operator_distribution ()
56- .check_count (self .ops_after_partitioner )
57- .to_executorch ()
58- # TODO(MLETORCH-632): Fix numerical errors
59- .run_method_and_compare_outputs (
60- rtol = 1.0 ,
61- atol = 5.0 ,
62- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
63- )
64- )
65-
66- def test_conformer_tosa_BI (self ):
67- (
68- ArmTester (
69- self .conformer ,
70- example_inputs = self .model_example_inputs ,
71- compile_spec = common .get_tosa_compile_spec (tosa_spec = "TOSA-0.80+BI" ),
72- )
73- .quantize ()
74- .export ()
75- .to_edge_transform_and_lower ()
76- .to_executorch ()
77- .run_method_and_compare_outputs (
78- qtol = 1.0 ,
79- rtol = 1.0 ,
80- atol = 5.0 ,
81- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
82- )
83- )
84-
85- def test_conformer_u55_BI (self ):
86- tester = (
87- ArmTester (
88- self .conformer ,
89- example_inputs = self .model_example_inputs ,
90- compile_spec = common .get_u55_compile_spec (),
91- )
92- .quantize ()
93- .export ()
94- .to_edge_transform_and_lower ()
95- .to_executorch ()
96- .serialize ()
97- )
98-
99- if conftest .is_option_enabled ("corstone_fvp" ):
100- try :
101- tester .run_method_and_compare_outputs (
102- qtol = 1.0 ,
103- rtol = 1.0 ,
104- atol = 5.0 ,
105- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
106- )
107- self .fail (
108- "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
109- )
110- except Exception :
111- pass
112-
113- def test_conformer_u85_BI (self ):
114- tester = (
115- ArmTester (
116- self .conformer ,
117- example_inputs = self .model_example_inputs ,
118- compile_spec = common .get_u85_compile_spec (),
119- )
120- .quantize ()
121- .export ()
122- .to_edge_transform_and_lower ()
123- .to_executorch ()
124- .serialize ()
125- )
126- if conftest .is_option_enabled ("corstone_fvp" ):
127- try :
128- tester .run_method_and_compare_outputs (
129- qtol = 1.0 ,
130- rtol = 1.0 ,
131- atol = 5.0 ,
132- inputs = get_test_inputs (self .dim , self .lengths , self .num_examples ),
133- )
134- self .fail (
135- "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
136- )
137- except Exception :
138- pass
51+
52+ def test_conformer_tosa_MI ():
53+ pipeline = TosaPipelineMI [input_t ](
54+ TestConformer .conformer ,
55+ TestConformer .model_example_inputs ,
56+ aten_op = TestConformer .aten_ops ,
57+ exir_op = [],
58+ use_to_edge_transform_and_lower = True ,
59+ )
60+ pipeline .change_args (
61+ "run_method_and_compare_outputs" ,
62+ get_test_inputs (
63+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
64+ ),
65+ rtol = 1.0 ,
66+ atol = 5.0 ,
67+ )
68+ pipeline .run ()
69+
70+
71+ def test_conformer_tosa_BI ():
72+ pipeline = TosaPipelineBI [input_t ](
73+ TestConformer .conformer ,
74+ TestConformer .model_example_inputs ,
75+ aten_op = TestConformer .aten_ops ,
76+ exir_op = [],
77+ use_to_edge_transform_and_lower = True ,
78+ )
79+ pipeline .pop_stage ("check_count.exir" )
80+ pipeline .change_args (
81+ "run_method_and_compare_outputs" ,
82+ get_test_inputs (
83+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
84+ ),
85+ rtol = 1.0 ,
86+ atol = 5.0 ,
87+ )
88+ pipeline .run ()
89+
90+
91+ @common .XfailIfNoCorstone300
92+ @pytest .mark .xfail (
93+ reason = "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
94+ )
95+ def test_conformer_u55_BI ():
96+ pipeline = EthosU55PipelineBI [input_t ](
97+ TestConformer .conformer ,
98+ TestConformer .model_example_inputs ,
99+ aten_ops = TestConformer .aten_ops ,
100+ exir_ops = [],
101+ use_to_edge_transform_and_lower = True ,
102+ run_on_fvp = True ,
103+ )
104+ pipeline .change_args (
105+ "run_method_and_compare_outputs" ,
106+ get_test_inputs (
107+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
108+ ),
109+ rtol = 1.0 ,
110+ atol = 5.0 ,
111+ )
112+ pipeline .run ()
113+
114+
115+ @common .XfailIfNoCorstone320
116+ @pytest .mark .xfail (reason = "All IO needs to have the same data type (MLETORCH-635)" )
117+ def test_conformer_u85_BI ():
118+ pipeline = EthosU85PipelineBI [input_t ](
119+ TestConformer .conformer ,
120+ TestConformer .model_example_inputs ,
121+ aten_ops = TestConformer .aten_ops ,
122+ exir_ops = [],
123+ use_to_edge_transform_and_lower = True ,
124+ run_on_fvp = True ,
125+ )
126+ pipeline .change_args (
127+ "run_method_and_compare_outputs" ,
128+ get_test_inputs (
129+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
130+ ),
131+ rtol = 1.0 ,
132+ atol = 5.0 ,
133+ )
134+ pipeline .run ()
0 commit comments