99
1010from typing import Tuple
1111
12+ import pytest
13+
1214import torch
1315
14- from executorch .backends .arm .test import common
16+ from executorch .backends .arm .test import common , conftest
1517from executorch .backends .arm .test .tester .arm_tester import ArmTester
1618from executorch .exir .backend .compile_spec_schema import CompileSpec
1719from parameterized import parameterized
@@ -40,7 +42,7 @@ def forward(self, x):
4042 def _test_tanh_tosa_MI_pipeline (
4143 self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
4244 ):
43- (
45+ tester = (
4446 ArmTester (
4547 module ,
4648 example_inputs = test_data ,
@@ -54,11 +56,13 @@ def _test_tanh_tosa_MI_pipeline(
5456 .check_not (["executorch_exir_dialects_edge__ops_aten_tanh_default" ])
5557 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
5658 .to_executorch ()
57- .run_method_and_compare_outputs (inputs = test_data )
5859 )
5960
61+ if conftest .is_option_enabled ("tosa_ref_model" ):
62+ tester .run_method_and_compare_outputs (inputs = test_data )
63+
6064 def _test_tanh_tosa_BI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
61- (
65+ tester = (
6266 ArmTester (
6367 module ,
6468 example_inputs = test_data ,
@@ -73,9 +77,11 @@ def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple)
7377 .check_not (["executorch_exir_dialects_edge__ops_aten_tanh_default" ])
7478 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7579 .to_executorch ()
76- .run_method_and_compare_outputs (inputs = test_data )
7780 )
7881
82+ if conftest .is_option_enabled ("tosa_ref_model" ):
83+ tester .run_method_and_compare_outputs (inputs = test_data )
84+
7985 def _test_tanh_tosa_ethos_BI_pipeline (
8086 self ,
8187 compile_spec : list [CompileSpec ],
@@ -114,6 +120,7 @@ def _test_tanh_tosa_u85_BI_pipeline(
114120 )
115121
116122 @parameterized .expand (test_data_suite )
123+ @pytest .mark .tosa_ref_model
117124 def test_tanh_tosa_MI (
118125 self ,
119126 test_name : str ,
@@ -122,6 +129,7 @@ def test_tanh_tosa_MI(
122129 self ._test_tanh_tosa_MI_pipeline (self .Tanh (), (test_data ,))
123130
124131 @parameterized .expand (test_data_suite )
132+ @pytest .mark .tosa_ref_model
125133 def test_tanh_tosa_BI (self , test_name : str , test_data : torch .Tensor ):
126134 self ._test_tanh_tosa_BI_pipeline (self .Tanh (), (test_data ,))
127135
0 commit comments