33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ import pytest
67import torch
78from executorch .backends .arm .quantizer .arm_quantizer import TOSAQuantizer
89from executorch .backends .arm .quantizer .quantization_config import QuantizationConfig
@@ -52,7 +53,7 @@ def _get_32_bit_quant_config():
5253 return qconfig
5354
5455
55- def get_16bit_sigmoid_quantizer (tosa_str : str ):
56+ def get_32bit_sigmoid_quantizer (tosa_str : str ):
5657 tosa_spec = common .TosaSpecification .create_from_string (tosa_str )
5758 quantizer = TOSAQuantizer (tosa_spec )
5859 quantizer .set_global (_get_32_bit_quant_config ())
@@ -65,12 +66,12 @@ def get_16bit_sigmoid_quantizer(tosa_str: str):
6566
6667input_t = tuple [torch .Tensor ]
6768test_data_suite = {
68- "ones" : ( torch .ones (10 , 10 , 10 ), ),
69- "rand" : ( torch .rand (10 , 10 ) - 0.5 ,) ,
70- "rand_4d" : ( torch .rand (1 , 10 , 10 , 10 ), ),
71- "randn_pos" : ( torch .randn (10 ) + 10 ,) ,
72- "randn_neg" : ( torch .randn (10 ) - 10 ,) ,
73- "ramp" : ( torch .arange (- 16 , 16 , 0.2 ), ),
69+ "ones" : lambda : torch .ones (10 , 10 , 10 ),
70+ "rand" : lambda : torch .rand (10 , 10 ) - 0.5 ,
71+ "rand_4d" : lambda : torch .rand (1 , 10 , 10 , 10 ),
72+ "randn_pos" : lambda : torch .randn (10 ) + 10 ,
73+ "randn_neg" : lambda : torch .randn (10 ) - 10 ,
74+ "ramp" : lambda : torch .arange (- 16 , 16 , 0.2 ),
7475}
7576
7677
@@ -96,28 +97,28 @@ def forward(self, x):
9697
9798
9899@common .parametrize ("test_data" , test_data_suite )
100+ @pytest .mark .flaky (reruns = 5 )
99101def test_sigmoid_tosa_BI (test_data ):
100102 pipeline = TosaPipelineBI (
101103 Sigmoid (),
102- test_data ,
104+ ( test_data (),) ,
103105 Sigmoid .aten_op ,
104106 Sigmoid .exir_op ,
105107 )
106- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
108+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
107109 pipeline .run ()
108110
109111
110112@common .parametrize ("test_data" , test_data_suite )
113+ @pytest .mark .flaky (reruns = 5 )
111114def test_sigmoid_add_sigmoid_tosa_BI (test_data ):
112115 pipeline = TosaPipelineBI (
113116 SigmoidAddSigmoid (),
114- test_data ,
117+ ( test_data (),) ,
115118 Sigmoid .aten_op ,
116119 Sigmoid .exir_op ,
117120 )
118- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
119- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
120-
121+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
121122 pipeline .run ()
122123
123124
@@ -129,16 +130,19 @@ def test_sigmoid_add_sigmoid_tosa_BI(test_data):
129130 "rand" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
130131 "rand_4d" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
131132 "randn_pos" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
133+ "randn_neg" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
132134 "ramp" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
133135 },
136+ # int16 tables are not supported, but some tests happen to pass regardless.
137+ # Set them to xfail but strict=False -> ok if they pass.
138+ strict = False ,
134139)
135140@common .XfailIfNoCorstone300
136141def test_sigmoid_tosa_u55 (test_data ):
137142 pipeline = EthosU55PipelineBI (
138- Sigmoid (), test_data , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
143+ Sigmoid (), ( test_data (),) , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
139144 )
140- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
141- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
145+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
142146 pipeline .run ()
143147
144148
@@ -153,29 +157,31 @@ def test_sigmoid_tosa_u55(test_data):
153157 "randn_neg" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
154158 "ramp" : "AssertionError: Output 0 does not match reference output. MLBEDSW-9770" ,
155159 },
160+ # int16 tables are not supported, but some tests happen to pass regardless.
161+ # Set them to xfail but strict=False -> ok if they pass.
162+ strict = False ,
156163)
157164@common .XfailIfNoCorstone300
158165def test_sigmoid_add_sigmoid_tosa_u55 (test_data ):
159166 pipeline = EthosU55PipelineBI (
160167 SigmoidAddSigmoid (),
161- test_data ,
168+ ( test_data (),) ,
162169 Sigmoid .aten_op ,
163170 Sigmoid .exir_op ,
164171 run_on_fvp = True ,
165172 )
166- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
167- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
173+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI+u55" ))
168174 pipeline .run ()
169175
170176
171177@common .parametrize ("test_data" , test_data_suite )
178+ @pytest .mark .flaky (reruns = 5 )
172179@common .XfailIfNoCorstone320
173180def test_sigmoid_tosa_u85 (test_data ):
174181 pipeline = EthosU85PipelineBI (
175- Sigmoid (), test_data , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
182+ Sigmoid (), ( test_data (),) , Sigmoid .aten_op , Sigmoid .exir_op , run_on_fvp = True
176183 )
177- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
178- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
184+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
179185 pipeline .run ()
180186
181187
@@ -186,15 +192,15 @@ def test_sigmoid_tosa_u85(test_data):
186192 "ramp" : "AssertionError: Output 0 does not match reference output." ,
187193 },
188194)
195+ @pytest .mark .flaky (reruns = 5 )
189196@common .XfailIfNoCorstone320
190197def test_sigmoid_add_sigmoid_tosa_u85 (test_data ):
191198 pipeline = EthosU85PipelineBI (
192199 SigmoidAddSigmoid (),
193- test_data ,
200+ ( test_data (),) ,
194201 Sigmoid .aten_op ,
195202 Sigmoid .exir_op ,
196203 run_on_fvp = True ,
197204 )
198- pipeline .change_args ("quantize" , get_16bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
199- pipeline .change_args ("run_method_and_compare_outputs" , test_data , qtol = 1 )
205+ pipeline .change_args ("quantize" , get_32bit_sigmoid_quantizer ("TOSA-0.80+BI" ))
200206 pipeline .run ()
0 commit comments