77import unittest
88from typing import Tuple
99
10+ import pytest
11+
1012import torch
1113
12- from executorch .backends .arm .test import common
14+ from executorch .backends .arm .test import common , conftest
1315from executorch .backends .arm .test .tester .arm_tester import ArmTester
1416from executorch .exir .backend .compile_spec_schema import CompileSpec
1517from parameterized import parameterized
@@ -35,7 +37,7 @@ def forward(self, x: torch.Tensor):
3537 def _test_slice_tosa_MI_pipeline (
3638 self , module : torch .nn .Module , test_data : torch .Tensor
3739 ):
38- (
40+ tester = (
3941 ArmTester (
4042 module ,
4143 example_inputs = test_data ,
@@ -48,14 +50,16 @@ def _test_slice_tosa_MI_pipeline(
4850 .partition ()
4951 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
5052 .to_executorch ()
51- .run_method_and_compare_outputs (inputs = test_data )
5253 )
5354
55+ if conftest .is_option_enabled ("tosa_ref_model" ):
56+ tester .run_method_and_compare_outputs (inputs = test_data )
57+
5458 def _test_slice_tosa_BI_pipeline (
5559 self , module : torch .nn .Module , test_data : Tuple [torch .Tensor ]
5660 ):
5761
58- (
62+ tester = (
5963 ArmTester (
6064 module ,
6165 example_inputs = test_data ,
@@ -68,9 +72,11 @@ def _test_slice_tosa_BI_pipeline(
6872 .partition ()
6973 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
7074 .to_executorch ()
71- .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
7275 )
7376
77+ if conftest .is_option_enabled ("tosa_ref_model" ):
78+ tester .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
79+
7480 def _test_slice_ethos_BI_pipeline (
7581 self ,
7682 compile_spec : list [CompileSpec ],
@@ -107,14 +113,17 @@ def _test_slice_u85_BI_pipeline(
107113 )
108114
109115 @parameterized .expand (Slice .test_tensors )
116+ @pytest .mark .tosa_ref_model
110117 def test_slice_tosa_MI (self , tensor ):
111118 self ._test_slice_tosa_MI_pipeline (self .Slice (), (tensor ,))
112119
113120 @parameterized .expand (Slice .test_tensors [:2 ])
121+ @pytest .mark .tosa_ref_model
114122 def test_slice_nchw_tosa_BI (self , test_tensor : torch .Tensor ):
115123 self ._test_slice_tosa_BI_pipeline (self .Slice (), (test_tensor ,))
116124
117125 @parameterized .expand (Slice .test_tensors [2 :])
126+ @pytest .mark .tosa_ref_model
118127 def test_slice_nhwc_tosa_BI (self , test_tensor : torch .Tensor ):
119128 self ._test_slice_tosa_BI_pipeline (self .Slice (), (test_tensor ,))
120129
0 commit comments