99
1010from typing import Tuple
1111
12+ import pytest
13+
1214import torch
13- from executorch .backends .arm .test import common
15+ from executorch .backends .arm .test import common , conftest
1416from executorch .backends .arm .test .tester .arm_tester import ArmTester
1517from executorch .exir .backend .compile_spec_schema import CompileSpec
1618from parameterized import parameterized
@@ -63,7 +65,7 @@ def forward(self, x, y):
6365 def _test_sigmoid_tosa_MI_pipeline (
6466 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
6567 ):
66- (
68+ tester = (
6769 ArmTester (
6870 module ,
6971 example_inputs = test_data ,
@@ -77,9 +79,11 @@ def _test_sigmoid_tosa_MI_pipeline(
7779 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
7880 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7981 .to_executorch ()
80- .run_method_and_compare_outputs (inputs = test_data )
8182 )
8283
84+ if conftest .is_option_enabled ("tosa_ref_model" ):
85+ tester .run_method_and_compare_outputs (inputs = test_data )
86+
8387 def _test_sigmoid_tosa_BI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
8488 (
8589 ArmTester (
@@ -96,7 +100,6 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup
96100 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
97101 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
98102 .to_executorch ()
99- .run_method_and_compare_outputs (inputs = test_data )
100103 )
101104
102105 def _test_sigmoid_tosa_ethos_BI_pipeline (
@@ -137,6 +140,7 @@ def _test_sigmoid_tosa_u85_BI_pipeline(
137140 )
138141
139142 @parameterized .expand (test_data_suite )
143+ @pytest .mark .tosa_ref_model
140144 def test_sigmoid_tosa_MI (
141145 self ,
142146 test_name : str ,
@@ -145,26 +149,33 @@ def test_sigmoid_tosa_MI(
145149 self ._test_sigmoid_tosa_MI_pipeline (self .Sigmoid (), (test_data ,))
146150
147151 @parameterized .expand (test_data_suite )
152+ @pytest .mark .tosa_ref_model
148153 def test_sigmoid_tosa_BI (self , test_name : str , test_data : torch .Tensor ):
149154 self ._test_sigmoid_tosa_BI_pipeline (self .Sigmoid (), (test_data ,))
150155
156+ @pytest .mark .tosa_ref_model
151157 def test_add_sigmoid_tosa_MI (self ):
152158 self ._test_sigmoid_tosa_MI_pipeline (self .AddSigmoid (), (test_data_suite [0 ][1 ],))
153159
160+ @pytest .mark .tosa_ref_model
154161 def test_add_sigmoid_tosa_BI (self ):
155162 self ._test_sigmoid_tosa_BI_pipeline (self .AddSigmoid (), (test_data_suite [5 ][1 ],))
156163
164+ @pytest .mark .tosa_ref_model
157165 def test_sigmoid_add_tosa_MI (self ):
158166 self ._test_sigmoid_tosa_MI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
159167
168+ @pytest .mark .tosa_ref_model
160169 def test_sigmoid_add_tosa_BI (self ):
161170 self ._test_sigmoid_tosa_BI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
162171
172+ @pytest .mark .tosa_ref_model
163173 def test_sigmoid_add_sigmoid_tosa_MI (self ):
164174 self ._test_sigmoid_tosa_MI_pipeline (
165175 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
166176 )
167177
178+ @pytest .mark .tosa_ref_model
168179 def test_sigmoid_add_sigmoid_tosa_BI (self ):
169180 self ._test_sigmoid_tosa_BI_pipeline (
170181 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
0 commit comments