1212import pytest
1313
1414import torch
15- from executorch .backends .arm .test import common
15+ from executorch .backends .arm .test import common , conftest
1616from executorch .backends .arm .test .tester .arm_tester import ArmTester
1717from executorch .exir .backend .compile_spec_schema import CompileSpec
1818from parameterized import parameterized
3030 lambda : ("randn_neg_dim" , torch .randn (10 , 5 , 8 , 7 ), - 3 ),
3131]
3232
33- test_data_generators_u55 = [
33+ test_data_generators_FVP = [
3434 # (test_name, test_data, dim)
3535 lambda : ("ones" , torch .ones (10 , 10 ), 1 ),
36- lambda : ("ones_neg_dim" , torch .ones (10 , 3 , 4 ), - 1 ),
37- lambda : ("randn_neg_dim" , torch .randn (10 , 5 , 8 , 7 ), - 3 ),
38- lambda : ("zeros" , torch .zeros (10 , 8 , 5 , 2 ), 0 ),
39- lambda : ("zeros_neg_dim" , torch .zeros (10 , 7 , 8 , 9 ), - 4 ),
36+ lambda : ("ones_neg_dim" , torch .ones (1 , 3 , 4 ), - 1 ),
37+ lambda : ("randn_neg_dim" , torch .randn (1 , 5 , 8 , 7 ), - 3 ),
38+ lambda : ("zeros" , torch .zeros (1 , 8 , 5 , 2 ), 0 ),
39+ lambda : ("zeros_neg_dim" , torch .zeros (1 , 7 , 8 , 9 ), - 4 ),
4040 lambda : ("rand" , torch .rand (1 , 2 , 5 , 8 ), 2 ),
41- lambda : ("rand_neg_dim" , torch .rand (2 , 10 , 8 , 10 ), - 2 ),
42- lambda : ("randn" , torch .randn (10 , 10 , 10 , 10 ), 3 ),
41+ lambda : ("rand_neg_dim" , torch .rand (1 , 10 , 8 , 10 ), - 2 ),
42+ lambda : ("randn" , torch .randn (1 , 10 , 10 , 10 ), 3 ),
4343]
4444
4545
@@ -95,13 +95,13 @@ def _test_softmax_tosa_BI_pipeline(
9595 .run_method_and_compare_outputs (inputs = test_data )
9696 )
9797
98- def _test_softmax_tosa_ethos_BI_pipeline (
98+ def _test_softmax_ethosu_BI_pipeline (
9999 self ,
100100 compile_spec : list [CompileSpec ],
101101 module : torch .nn .Module ,
102102 test_data : Tuple [torch .tensor ],
103103 ):
104- (
104+ tester = (
105105 ArmTester (
106106 module ,
107107 example_inputs = test_data ,
@@ -116,21 +116,10 @@ def _test_softmax_tosa_ethos_BI_pipeline(
116116 .check_not (["executorch_exir_dialects_edge__ops_aten__softmax_default" ])
117117 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
118118 .to_executorch ()
119+ .serialize ()
119120 )
120-
121- def _test_softmax_tosa_u55_BI_pipeline (
122- self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
123- ):
124- self ._test_softmax_tosa_ethos_BI_pipeline (
125- common .get_u55_compile_spec (), module , test_data
126- )
127-
128- def _test_softmax_tosa_u85_BI_pipeline (
129- self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
130- ):
131- self ._test_softmax_tosa_ethos_BI_pipeline (
132- common .get_u85_compile_spec (), module , test_data
133- )
121+ if conftest .is_option_enabled ("corstone_fvp" ):
122+ tester .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
134123
135124 @parameterized .expand (test_data_generators )
136125 def test_softmax_tosa_MI (self , test_data_generator : Callable [[], Tuple ]):
@@ -143,14 +132,18 @@ def test_softmax_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
143132 test_name , test_data , dim = test_data_generator ()
144133 self ._test_softmax_tosa_BI_pipeline (self .Softmax (dim = dim ), (test_data ,))
145134
146- @parameterized .expand (test_data_generators_u55 )
135+ @parameterized .expand (test_data_generators_FVP )
147136 @pytest .mark .flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
148- def test_softmax_tosa_u55_BI (self , test_data_generator : Callable [[], Tuple ]):
137+ def test_softmax_u55_BI (self , test_data_generator : Callable [[], Tuple ]):
149138 test_name , test_data , dim = test_data_generator ()
150- self ._test_softmax_tosa_u55_BI_pipeline (self .Softmax (dim = dim ), (test_data ,))
139+ self ._test_softmax_ethosu_BI_pipeline (
140+ common .get_u55_compile_spec (), self .Softmax (dim = dim ), (test_data ,)
141+ )
151142
152- @parameterized .expand (test_data_generators )
143+ @parameterized .expand (test_data_generators_FVP )
153144 @pytest .mark .flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
154- def test_softmax_tosa_u85_BI (self , test_data_generator : Callable [[], Tuple ]):
145+ def test_softmax_u85_BI (self , test_data_generator : Callable [[], Tuple ]):
155146 test_name , test_data , dim = test_data_generator ()
156- self ._test_softmax_tosa_u85_BI_pipeline (self .Softmax (dim = dim ), (test_data ,))
147+ self ._test_softmax_ethosu_BI_pipeline (
148+ common .get_u85_compile_spec (), self .Softmax (dim = dim ), (test_data ,)
149+ )
0 commit comments