99
1010from typing import Tuple
1111
12+ import pytest
13+
1214import torch
1315
14- from executorch .backends .arm .test import common
16+ from executorch .backends .arm .test import common , conftest
1517
1618from executorch .backends .arm .test .tester .test_pipeline import (
1719 EthosU55PipelineBI ,
@@ -64,15 +66,24 @@ def forward(self, x):
6466
6567
6668@common .parametrize ("test_module" , test_modules )
69+ @pytest .mark .tosa_ref_model
6770def test_avgpool2d_tosa_MI (test_module ):
6871 model , input_tensor = test_module
6972
70- pipeline = TosaPipelineMI [input_t ](model , input_tensor , aten_op , exir_op )
71- pipeline .change_args ("run_method_and_compare_outputs" , qtol = 1 , atol = 1 , rtol = 1 )
72- pipeline .run ()
73+ pipeline = TosaPipelineMI [input_t ](
74+ model ,
75+ input_tensor ,
76+ aten_op ,
77+ exir_op ,
78+ run_on_tosa_ref_model = conftest .is_option_enabled ("tosa_ref_model" ),
79+ )
80+ if conftest .is_option_enabled ("tosa_ref_model" ):
81+ pipeline .change_args ("run_method_and_compare_outputs" , qtol = 1 , atol = 1 , rtol = 1 )
82+ pipeline .run ()
7383
7484
7585@common .parametrize ("test_module" , test_modules )
86+ @pytest .mark .tosa_ref_model
7687def test_avgpool2d_tosa_BI (test_module ):
7788 model , input_tensor = test_module
7889
@@ -82,9 +93,11 @@ def test_avgpool2d_tosa_BI(test_module):
8293 aten_op ,
8394 exir_op ,
8495 symmetric_io_quantization = True ,
96+ run_on_tosa_ref_model = conftest .is_option_enabled ("tosa_ref_model" ),
8597 )
86- pipeline .change_args ("run_method_and_compare_outputs" , qtol = 1 , atol = 1 , rtol = 1 )
87- pipeline .run ()
98+ if conftest .is_option_enabled ("tosa_ref_model" ):
99+ pipeline .change_args ("run_method_and_compare_outputs" , qtol = 1 , atol = 1 , rtol = 1 )
100+ pipeline .run ()
88101
89102
90103@common .parametrize ("test_module" , test_modules )
0 commit comments