55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
8+ import pytest
89import unittest
910
1011from typing import Tuple
1112
1213import torch
13- from executorch .backends .arm .test import common
14+ from executorch .backends .arm .test import common , conftest
1415from executorch .backends .arm .test .tester .arm_tester import ArmTester
1516from executorch .exir .backend .compile_spec_schema import CompileSpec
1617from parameterized import parameterized
@@ -63,7 +64,7 @@ def forward(self, x, y):
6364 def _test_sigmoid_tosa_MI_pipeline (
6465 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
6566 ):
66- (
67+ tester = (
6768 ArmTester (
6869 module ,
6970 example_inputs = test_data ,
@@ -77,11 +78,13 @@ def _test_sigmoid_tosa_MI_pipeline(
7778 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
7879 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7980 .to_executorch ()
80- .run_method_and_compare_outputs (inputs = test_data )
8181 )
8282
83+ if conftest .is_option_enabled ("tosa_ref_model" ):
84+ tester .run_method_and_compare_outputs (inputs = test_data )
85+
8386 def _test_sigmoid_tosa_BI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
84- (
87+ tester = (
8588 ArmTester (
8689 module ,
8790 example_inputs = test_data ,
@@ -96,16 +99,18 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup
9699 .check_not (["executorch_exir_dialects_edge__ops_aten_sigmoid_default" ])
97100 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
98101 .to_executorch ()
99- .run_method_and_compare_outputs (inputs = test_data )
100102 )
101103
104+ if conftest .is_option_enabled ("tosa_ref_model" ):
105+ tester .run_method_and_compare_outputs (inputs = test_data )
106+
102107 def _test_sigmoid_tosa_ethos_BI_pipeline (
103108 self ,
104109 compile_spec : list [CompileSpec ],
105110 module : torch .nn .Module ,
106111 test_data : Tuple [torch .tensor ],
107112 ):
108- (
113+ tester = (
109114 ArmTester (
110115 module ,
111116 example_inputs = test_data ,
@@ -122,6 +127,9 @@ def _test_sigmoid_tosa_ethos_BI_pipeline(
122127 .to_executorch ()
123128 )
124129
130+ if conftest .is_option_enabled ("tosa_ref_model" ):
131+ tester .run_method_and_compare_outputs (inputs = test_data )
132+
125133 def _test_sigmoid_tosa_u55_BI_pipeline (
126134 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
127135 ):
@@ -137,6 +145,7 @@ def _test_sigmoid_tosa_u85_BI_pipeline(
137145 )
138146
139147 @parameterized .expand (test_data_suite )
148+ @pytest .mark .tosa_ref_model
140149 def test_sigmoid_tosa_MI (
141150 self ,
142151 test_name : str ,
@@ -145,26 +154,33 @@ def test_sigmoid_tosa_MI(
145154 self ._test_sigmoid_tosa_MI_pipeline (self .Sigmoid (), (test_data ,))
146155
147156 @parameterized .expand (test_data_suite )
157+ @pytest .mark .tosa_ref_model
148158 def test_sigmoid_tosa_BI (self , test_name : str , test_data : torch .Tensor ):
149159 self ._test_sigmoid_tosa_BI_pipeline (self .Sigmoid (), (test_data ,))
150160
161+ @pytest .mark .tosa_ref_model
151162 def test_add_sigmoid_tosa_MI (self ):
152163 self ._test_sigmoid_tosa_MI_pipeline (self .AddSigmoid (), (test_data_suite [0 ][1 ],))
153164
165+ @pytest .mark .tosa_ref_model
154166 def test_add_sigmoid_tosa_BI (self ):
155167 self ._test_sigmoid_tosa_BI_pipeline (self .AddSigmoid (), (test_data_suite [5 ][1 ],))
156168
169+ @pytest .mark .tosa_ref_model
157170 def test_sigmoid_add_tosa_MI (self ):
158171 self ._test_sigmoid_tosa_MI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
159172
173+ @pytest .mark .tosa_ref_model
160174 def test_sigmoid_add_tosa_BI (self ):
161175 self ._test_sigmoid_tosa_BI_pipeline (self .SigmoidAdd (), (test_data_suite [0 ][1 ],))
162176
177+ @pytest .mark .tosa_ref_model
163178 def test_sigmoid_add_sigmoid_tosa_MI (self ):
164179 self ._test_sigmoid_tosa_MI_pipeline (
165180 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
166181 )
167182
183+ @pytest .mark .tosa_ref_model
168184 def test_sigmoid_add_sigmoid_tosa_BI (self ):
169185 self ._test_sigmoid_tosa_BI_pipeline (
170186 self .SigmoidAddSigmoid (), (test_data_suite [4 ][1 ], test_data_suite [3 ][1 ])
0 commit comments