77
88import pytest
99
10- # Dynamically import the script
11- script_path = os .path .join (".ci" , "scripts" , "gather_benchmark_configs.py" )
12- spec = importlib .util .spec_from_file_location ("gather_benchmark_configs" , script_path )
13- gather_benchmark_configs = importlib .util .module_from_spec (spec )
14- spec .loader .exec_module (gather_benchmark_configs )
15-
1610
1711@pytest .mark .skipif (
1812 sys .platform != "linux" , reason = "The script under test runs on Linux runners only"
1913)
2014class TestGatehrBenchmarkConfigs (unittest .TestCase ):
2115
16+ @classmethod
17+ def setUpClass (cls ):
18+ # Dynamically import the script
19+ script_path = os .path .join (".ci" , "scripts" , "gather_benchmark_configs.py" )
20+ spec = importlib .util .spec_from_file_location (
21+ "gather_benchmark_configs" , script_path
22+ )
23+ cls .gather_benchmark_configs = importlib .util .module_from_spec (spec )
24+ spec .loader .exec_module (cls .gather_benchmark_configs )
25+
2226 def test_extract_all_configs_android (self ):
23- android_configs = gather_benchmark_configs .extract_all_configs (
24- gather_benchmark_configs .BENCHMARK_CONFIGS , "android"
27+ android_configs = self . gather_benchmark_configs .extract_all_configs (
28+ self . gather_benchmark_configs .BENCHMARK_CONFIGS , "android"
2529 )
2630 self .assertIn ("xnnpack_q8" , android_configs )
2731 self .assertIn ("qnn_q8" , android_configs )
2832 self .assertIn ("llama3_spinquant" , android_configs )
2933 self .assertIn ("llama3_qlora" , android_configs )
3034
3135 def test_extract_all_configs_ios (self ):
32- ios_configs = gather_benchmark_configs .extract_all_configs (
33- gather_benchmark_configs .BENCHMARK_CONFIGS , "ios"
36+ ios_configs = self . gather_benchmark_configs .extract_all_configs (
37+ self . gather_benchmark_configs .BENCHMARK_CONFIGS , "ios"
3438 )
3539
3640 self .assertIn ("xnnpack_q8" , ios_configs )
@@ -40,51 +44,82 @@ def test_extract_all_configs_ios(self):
4044 self .assertIn ("llama3_spinquant" , ios_configs )
4145 self .assertIn ("llama3_qlora" , ios_configs )
4246
47+ def test_skip_disabled_configs (self ):
48+ # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
49+ with patch .dict (
50+ self .gather_benchmark_configs .DISABLED_CONFIGS ,
51+ {"mv3" : ["disabled_config1" , "disabled_config2" ]},
52+ ), patch .dict (
53+ self .gather_benchmark_configs .BENCHMARK_CONFIGS ,
54+ {
55+ "ios" : [
56+ "disabled_config1" ,
57+ "disabled_config2" ,
58+ "enabled_config1" ,
59+ "enabled_config2" ,
60+ ]
61+ },
62+ ):
63+ result = self .gather_benchmark_configs .generate_compatible_configs (
64+ "mv3" , target_os = "ios"
65+ )
66+
67+ # Assert that disabled configs are excluded
68+ self .assertNotIn ("disabled_config1" , result )
69+ self .assertNotIn ("disabled_config2" , result )
70+ # Assert enabled configs are included
71+ self .assertIn ("enabled_config1" , result )
72+ self .assertIn ("enabled_config2" , result )
73+
4374 def test_generate_compatible_configs_llama_model (self ):
4475 model_name = "meta-llama/Llama-3.2-1B"
4576 target_os = "ios"
46- result = gather_benchmark_configs .generate_compatible_configs (
77+ result = self . gather_benchmark_configs .generate_compatible_configs (
4778 model_name , target_os
4879 )
4980 expected = ["llama3_fb16" , "llama3_coreml_ane" ]
5081 self .assertEqual (result , expected )
5182
5283 target_os = "android"
53- result = gather_benchmark_configs .generate_compatible_configs (
84+ result = self . gather_benchmark_configs .generate_compatible_configs (
5485 model_name , target_os
5586 )
5687 expected = ["llama3_fb16" ]
5788 self .assertEqual (result , expected )
5889
5990 def test_generate_compatible_configs_quantized_llama_model (self ):
6091 model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
61- result = gather_benchmark_configs .generate_compatible_configs (model_name , None )
92+ result = self .gather_benchmark_configs .generate_compatible_configs (
93+ model_name , None
94+ )
6295 expected = ["llama3_spinquant" ]
6396 self .assertEqual (result , expected )
6497
6598 model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
66- result = gather_benchmark_configs .generate_compatible_configs (model_name , None )
99+ result = self .gather_benchmark_configs .generate_compatible_configs (
100+ model_name , None
101+ )
67102 expected = ["llama3_qlora" ]
68103 self .assertEqual (result , expected )
69104
70105 def test_generate_compatible_configs_non_genai_model (self ):
71106 model_name = "mv2"
72107 target_os = "xplat"
73- result = gather_benchmark_configs .generate_compatible_configs (
108+ result = self . gather_benchmark_configs .generate_compatible_configs (
74109 model_name , target_os
75110 )
76111 expected = ["xnnpack_q8" ]
77112 self .assertEqual (result , expected )
78113
79114 target_os = "android"
80- result = gather_benchmark_configs .generate_compatible_configs (
115+ result = self . gather_benchmark_configs .generate_compatible_configs (
81116 model_name , target_os
82117 )
83118 expected = ["xnnpack_q8" , "qnn_q8" ]
84119 self .assertEqual (result , expected )
85120
86121 target_os = "ios"
87- result = gather_benchmark_configs .generate_compatible_configs (
122+ result = self . gather_benchmark_configs .generate_compatible_configs (
88123 model_name , target_os
89124 )
90125 expected = ["xnnpack_q8" , "coreml_fp16" , "mps" ]
@@ -93,22 +128,22 @@ def test_generate_compatible_configs_non_genai_model(self):
93128 def test_generate_compatible_configs_unknown_model (self ):
94129 model_name = "unknown_model"
95130 target_os = "ios"
96- result = gather_benchmark_configs .generate_compatible_configs (
131+ result = self . gather_benchmark_configs .generate_compatible_configs (
97132 model_name , target_os
98133 )
99134 self .assertEqual (result , [])
100135
101136 def test_is_valid_huggingface_model_id_valid (self ):
102137 valid_model = "meta-llama/Llama-3.2-1B"
103138 self .assertTrue (
104- gather_benchmark_configs .is_valid_huggingface_model_id (valid_model )
139+ self . gather_benchmark_configs .is_valid_huggingface_model_id (valid_model )
105140 )
106141
107142 @patch ("builtins.open" , new_callable = mock_open )
108143 @patch ("os.getenv" , return_value = None )
109144 def test_set_output_no_github_env (self , mock_getenv , mock_file ):
110145 with patch ("builtins.print" ) as mock_print :
111- gather_benchmark_configs .set_output ("test_name" , "test_value" )
146+ self . gather_benchmark_configs .set_output ("test_name" , "test_value" )
112147 mock_print .assert_called_with ("::set-output name=test_name::test_value" )
113148
114149 def test_device_pools_contains_all_devices (self ):
@@ -120,7 +155,7 @@ def test_device_pools_contains_all_devices(self):
120155 "google_pixel_8_pro" ,
121156 ]
122157 for device in expected_devices :
123- self .assertIn (device , gather_benchmark_configs .DEVICE_POOLS )
158+ self .assertIn (device , self . gather_benchmark_configs .DEVICE_POOLS )
124159
125160 def test_gather_benchmark_configs_cli (self ):
126161 args = {
0 commit comments