7
7
import unittest
8
8
from typing import Tuple
9
9
10
+ import pytest
11
+
10
12
import torch
11
13
12
- from executorch .backends .arm .test import common
14
+ from executorch .backends .arm .test import common , conftest
13
15
from executorch .backends .arm .test .tester .arm_tester import ArmTester
14
16
from executorch .exir .backend .compile_spec_schema import CompileSpec
15
17
from parameterized import parameterized
@@ -35,7 +37,7 @@ def forward(self, x: torch.Tensor):
35
37
def _test_slice_tosa_MI_pipeline (
36
38
self , module : torch .nn .Module , test_data : torch .Tensor
37
39
):
38
- (
40
+ tester = (
39
41
ArmTester (
40
42
module ,
41
43
example_inputs = test_data ,
@@ -48,14 +50,16 @@ def _test_slice_tosa_MI_pipeline(
48
50
.partition ()
49
51
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
50
52
.to_executorch ()
51
- .run_method_and_compare_outputs (inputs = test_data )
52
53
)
53
54
55
+ if conftest .is_option_enabled ("tosa_ref_model" ):
56
+ tester .run_method_and_compare_outputs (inputs = test_data )
57
+
54
58
def _test_slice_tosa_BI_pipeline (
55
59
self , module : torch .nn .Module , test_data : Tuple [torch .Tensor ]
56
60
):
57
61
58
- (
62
+ tester = (
59
63
ArmTester (
60
64
module ,
61
65
example_inputs = test_data ,
@@ -68,9 +72,11 @@ def _test_slice_tosa_BI_pipeline(
68
72
.partition ()
69
73
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
70
74
.to_executorch ()
71
- .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
72
75
)
73
76
77
+ if conftest .is_option_enabled ("tosa_ref_model" ):
78
+ tester .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
79
+
74
80
def _test_slice_ethos_BI_pipeline (
75
81
self ,
76
82
compile_spec : list [CompileSpec ],
@@ -107,14 +113,17 @@ def _test_slice_u85_BI_pipeline(
107
113
)
108
114
109
115
@parameterized .expand (Slice .test_tensors )
116
+ @pytest .mark .tosa_ref_model
110
117
def test_slice_tosa_MI (self , tensor ):
111
118
self ._test_slice_tosa_MI_pipeline (self .Slice (), (tensor ,))
112
119
113
120
@parameterized .expand (Slice .test_tensors [:2 ])
121
+ @pytest .mark .tosa_ref_model
114
122
def test_slice_nchw_tosa_BI (self , test_tensor : torch .Tensor ):
115
123
self ._test_slice_tosa_BI_pipeline (self .Slice (), (test_tensor ,))
116
124
117
125
@parameterized .expand (Slice .test_tensors [2 :])
126
+ @pytest .mark .tosa_ref_model
118
127
def test_slice_nhwc_tosa_BI (self , test_tensor : torch .Tensor ):
119
128
self ._test_slice_tosa_BI_pipeline (self .Slice (), (test_tensor ,))
120
129
0 commit comments