99
1010from typing import Tuple
1111
12+ import pytest
13+
1214import torch
13- from executorch .backends .arm .test import common
15+ from executorch .backends .arm .test import common , conftest
1416from executorch .backends .arm .test .tester .arm_tester import ArmTester
1517from executorch .exir .backend .compile_spec_schema import CompileSpec
1618from parameterized import parameterized
@@ -63,7 +65,7 @@ def forward(self, x, y):
6365 def _test_sigmoid_tosa_MI_pipeline (
6466 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
6567 ):
66- (
68+ tester = (
6769 ArmTester (
6870 module ,
6971 example_inputs = test_data ,
@@ -77,11 +79,13 @@ def _test_sigmoid_tosa_MI_pipeline(
7779 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
7880 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7981 .to_executorch ()
80- .run_method_and_compare_outputs (inputs = test_data )
8182 )
8283
84+ if conftest .is_option_enabled ("tosa_ref_model" ):
85+ tester .run_method_and_compare_outputs (inputs = test_data )
86+
8387 def _test_sigmoid_tosa_BI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
84- (
88+ tester = (
8589 ArmTester (
8690 module ,
8791 example_inputs = test_data ,
@@ -96,9 +100,11 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup
96100 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
97101 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
98102 .to_executorch ()
99- .run_method_and_compare_outputs (inputs = test_data )
100103 )
101104
105+ if conftest .is_option_enabled ("tosa_ref_model" ):
106+ tester .run_method_and_compare_outputs (inputs = test_data )
107+
102108 def _test_sigmoid_tosa_ethos_BI_pipeline (
103109 self ,
104110 compile_spec : list [CompileSpec ],
@@ -137,6 +143,7 @@ def _test_sigmoid_tosa_u85_BI_pipeline(
137143 )
138144
139145 @parameterized .expand (test_data_suite )
146+ @pytest .mark .tosa_ref_model
140147 def test_sigmoid_tosa_MI (
141148 self ,
142149 test_name : str ,
@@ -145,26 +152,33 @@ def test_sigmoid_tosa_MI(
145152 self ._test_sigmoid_tosa_MI_pipeline (self .Sigmoid (), (test_data ,))
146153
147154 @parameterized .expand (test_data_suite )
155+ @pytest .mark .tosa_ref_model
148156 def test_sigmoid_tosa_BI (self , test_name : str , test_data : torch .Tensor ):
149157 self ._test_sigmoid_tosa_BI_pipeline (self .Sigmoid (), (test_data ,))
150158
159+ @pytest .mark .tosa_ref_model
151160 def test_add_sigmoid_tosa_MI (self ):
152161 self ._test_sigmoid_tosa_MI_pipeline (self .AddSigmoid (), (test_data_suite [0 ][1 ],))
153162
163+ @pytest .mark .tosa_ref_model
154164 def test_add_sigmoid_tosa_BI (self ):
155165 self ._test_sigmoid_tosa_BI_pipeline (self .AddSigmoid (), (test_data_suite [5 ][1 ],))
156166
167+ @pytest .mark .tosa_ref_model
157168 def test_sigmoid_add_tosa_MI (self ):
158169 self ._test_sigmoid_tosa_MI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
159170
171+ @pytest .mark .tosa_ref_model
160172 def test_sigmoid_add_tosa_BI (self ):
161173 self ._test_sigmoid_tosa_BI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
162174
175+ @pytest .mark .tosa_ref_model
163176 def test_sigmoid_add_sigmoid_tosa_MI (self ):
164177 self ._test_sigmoid_tosa_MI_pipeline (
165178 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
166179 )
167180
181+ @pytest .mark .tosa_ref_model
168182 def test_sigmoid_add_sigmoid_tosa_BI (self ):
169183 self ._test_sigmoid_tosa_BI_pipeline (
170184 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
0 commit comments