99
1010from typing import Tuple
1111
12+ import pytest
13+
1214import torch
13- from executorch .backends .arm .test import common
15+ from executorch .backends .arm .test import common , conftest
1416from executorch .backends .arm .test .tester .arm_tester import ArmTester
1517from executorch .exir .backend .compile_spec_schema import CompileSpec
1618from parameterized import parameterized
@@ -63,7 +65,7 @@ def forward(self, x, y):
6365 def _test_sigmoid_tosa_MI_pipeline (
6466 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
6567 ):
66- (
68+ tester = (
6769 ArmTester (
6870 module ,
6971 example_inputs = test_data ,
@@ -77,11 +79,13 @@ def _test_sigmoid_tosa_MI_pipeline(
7779 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
7880 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7981 .to_executorch ()
80- .run_method_and_compare_outputs (inputs = test_data )
8182 )
8283
84+ if conftest .is_option_enabled ("tosa_ref_model" ):
85+ tester .run_method_and_compare_outputs (inputs = test_data )
86+
8387 def _test_sigmoid_tosa_BI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
84- (
88+ tester = (
8589 ArmTester (
8690 module ,
8791 example_inputs = test_data ,
@@ -96,16 +100,18 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup
96100 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
97101 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
98102 .to_executorch ()
99- .run_method_and_compare_outputs (inputs = test_data )
100103 )
101104
105+ if conftest .is_option_enabled ("tosa_ref_model" ):
106+ tester .run_method_and_compare_outputs (inputs = test_data )
107+
102108 def _test_sigmoid_tosa_ethos_BI_pipeline (
103109 self ,
104110 compile_spec : list [CompileSpec ],
105111 module : torch .nn .Module ,
106112 test_data : Tuple [torch .tensor ],
107113 ):
108- (
114+ tester = (
109115 ArmTester (
110116 module ,
111117 example_inputs = test_data ,
@@ -122,6 +128,9 @@ def _test_sigmoid_tosa_ethos_BI_pipeline(
122128 .to_executorch ()
123129 )
124130
131+ if conftest .is_option_enabled ("tosa_ref_model" ):
132+ tester .run_method_and_compare_outputs (inputs = test_data )
133+
125134 def _test_sigmoid_tosa_u55_BI_pipeline (
126135 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
127136 ):
@@ -137,6 +146,7 @@ def _test_sigmoid_tosa_u85_BI_pipeline(
137146 )
138147
139148 @parameterized .expand (test_data_suite )
149+ @pytest .mark .tosa_ref_model
140150 def test_sigmoid_tosa_MI (
141151 self ,
142152 test_name : str ,
@@ -145,26 +155,33 @@ def test_sigmoid_tosa_MI(
145155 self ._test_sigmoid_tosa_MI_pipeline (self .Sigmoid (), (test_data ,))
146156
147157 @parameterized .expand (test_data_suite )
158+ @pytest .mark .tosa_ref_model
148159 def test_sigmoid_tosa_BI (self , test_name : str , test_data : torch .Tensor ):
149160 self ._test_sigmoid_tosa_BI_pipeline (self .Sigmoid (), (test_data ,))
150161
162+ @pytest .mark .tosa_ref_model
151163 def test_add_sigmoid_tosa_MI (self ):
152164 self ._test_sigmoid_tosa_MI_pipeline (self .AddSigmoid (), (test_data_suite [0 ][1 ],))
153165
166+ @pytest .mark .tosa_ref_model
154167 def test_add_sigmoid_tosa_BI (self ):
155168 self ._test_sigmoid_tosa_BI_pipeline (self .AddSigmoid (), (test_data_suite [5 ][1 ],))
156169
170+ @pytest .mark .tosa_ref_model
157171 def test_sigmoid_add_tosa_MI (self ):
158172 self ._test_sigmoid_tosa_MI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
159173
174+ @pytest .mark .tosa_ref_model
160175 def test_sigmoid_add_tosa_BI (self ):
161176 self ._test_sigmoid_tosa_BI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
162177
178+ @pytest .mark .tosa_ref_model
163179 def test_sigmoid_add_sigmoid_tosa_MI (self ):
164180 self ._test_sigmoid_tosa_MI_pipeline (
165181 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
166182 )
167183
184+ @pytest .mark .tosa_ref_model
168185 def test_sigmoid_add_sigmoid_tosa_BI (self ):
169186 self ._test_sigmoid_tosa_BI_pipeline (
170187 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
0 commit comments