11import importlib .util
22import os
3+ import re
34import subprocess
45import sys
56import unittest
67from unittest .mock import mock_open , patch
78
89import pytest
910
10- # Dynamically import the script
11- script_path = os .path .join (".ci" , "scripts" , "gather_benchmark_configs.py" )
12- spec = importlib .util .spec_from_file_location ("gather_benchmark_configs" , script_path )
13- gather_benchmark_configs = importlib .util .module_from_spec (spec )
14- spec .loader .exec_module (gather_benchmark_configs )
15-
1611
1712@pytest .mark .skipif (
1813 sys .platform != "linux" , reason = "The script under test runs on Linux runners only"
1914)
2015class TestGatehrBenchmarkConfigs (unittest .TestCase ):
2116
@classmethod
def setUpClass(cls):
    """Load the script under test once and keep it on the class.

    The script is loaded directly from its file path (relative to the
    current working directory) and the resulting module object is stored as
    ``cls.gather_benchmark_configs`` for the test methods to use.
    """
    module_name = "gather_benchmark_configs"
    location = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
    module_spec = importlib.util.spec_from_file_location(module_name, location)
    # Attach the (still empty) module to the class first, then execute the
    # script's top-level code into it.
    cls.gather_benchmark_configs = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(cls.gather_benchmark_configs)
26+
def test_extract_all_configs_android(self):
    """Check that the configs extracted for "android" include the expected names."""
    module = self.gather_benchmark_configs
    configs = module.extract_all_configs(module.BENCHMARK_CONFIGS, "android")
    for name in ("xnnpack_q8", "qnn_q8", "llama3_spinquant", "llama3_qlora"):
        self.assertIn(name, configs)
3035
3136 def test_extract_all_configs_ios (self ):
32- ios_configs = gather_benchmark_configs .extract_all_configs (
33- gather_benchmark_configs .BENCHMARK_CONFIGS , "ios"
37+ ios_configs = self . gather_benchmark_configs .extract_all_configs (
38+ self . gather_benchmark_configs .BENCHMARK_CONFIGS , "ios"
3439 )
3540
3641 self .assertIn ("xnnpack_q8" , ios_configs )
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
4045 self .assertIn ("llama3_spinquant" , ios_configs )
4146 self .assertIn ("llama3_qlora" , ios_configs )
4247
def test_skip_disabled_configs(self):
    """Check that configs disabled for a model are left out of the result.

    ``DISABLED_CONFIGS`` and ``BENCHMARK_CONFIGS`` are changed through
    ``patch.dict`` so the module's dictionaries are restored when the
    ``with`` block exits.
    """
    module = self.gather_benchmark_configs
    disabled_for_mv3 = {
        "mv3": [
            module.DisabledConfig(
                config_name="disabled_config1",
                github_issue="https://github.com/org/repo/issues/123",
            ),
            module.DisabledConfig(
                config_name="disabled_config2",
                github_issue="https://github.com/org/repo/issues/124",
            ),
        ]
    }
    ios_candidates = {
        "ios": [
            "disabled_config1",
            "disabled_config2",
            "enabled_config1",
            "enabled_config2",
        ]
    }

    with patch.dict(module.DISABLED_CONFIGS, disabled_for_mv3), patch.dict(
        module.BENCHMARK_CONFIGS, ios_candidates
    ):
        result = module.generate_compatible_configs("mv3", target_os="ios")

        # Disabled configs must be excluded.
        for name in ("disabled_config1", "disabled_config2"):
            self.assertNotIn(name, result)
        # Enabled configs must still be present.
        for name in ("enabled_config1", "enabled_config2"):
            self.assertIn(name, result)
85+
def test_disabled_configs_have_github_links(self):
    """Check every entry in ``DISABLED_CONFIGS`` carries a GitHub issue link.

    Each entry must be a ``DisabledConfig`` whose ``github_issue`` is exactly
    of the form ``https://github.com/<owner>/<repo>/issues/<number>``.
    """
    module = self.gather_benchmark_configs
    # Owner and repo are single path segments. The pattern is applied with
    # fullmatch, so text after the issue number is rejected; re.match only
    # anchors at the start and would let trailing junk through.
    github_issue_regex = re.compile(r"https://github\.com/[^/]+/[^/]+/issues/\d+")

    for model_name, disabled_configs in module.DISABLED_CONFIGS.items():
        for disabled in disabled_configs:
            # getattr keeps a malformed entry (no config_name) from raising
            # before the sub-test starts, so the isinstance assertion below
            # is what reports it.
            label = getattr(disabled, "config_name", repr(disabled))
            with self.subTest(model_name=model_name, config=label):
                self.assertIsInstance(disabled, module.DisabledConfig)

                # A missing/empty link short-circuits to a falsy value.
                issue = disabled.github_issue
                self.assertTrue(
                    issue and github_issue_regex.fullmatch(issue),
                    f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
                )
106+
def test_generate_compatible_configs_llama_model(self):
    """Check the configs generated for a base Llama model on iOS and on Android."""
    generate = self.gather_benchmark_configs.generate_compatible_configs
    model_name = "meta-llama/Llama-3.2-1B"

    ios_result = generate(model_name, "ios")
    self.assertEqual(ios_result, ["llama3_fb16", "llama3_coreml_ane"])

    android_result = generate(model_name, "android")
    self.assertEqual(android_result, ["llama3_fb16"])
58122
def test_generate_compatible_configs_quantized_llama_model(self):
    """Check the configs generated for the SpinQuant and QLoRA Llama models.

    ``None`` is passed as the target OS in both cases.
    """
    cases = (
        ("meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", ["llama3_spinquant"]),
        ("meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", ["llama3_qlora"]),
    )
    for model_name, expected in cases:
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, None
        )
        self.assertEqual(result, expected)
69137
70138 def test_generate_compatible_configs_non_genai_model (self ):
71139 model_name = "mv2"
72140 target_os = "xplat"
73- result = gather_benchmark_configs .generate_compatible_configs (
141+ result = self . gather_benchmark_configs .generate_compatible_configs (
74142 model_name , target_os
75143 )
76144 expected = ["xnnpack_q8" ]
77145 self .assertEqual (result , expected )
78146
79147 target_os = "android"
80- result = gather_benchmark_configs .generate_compatible_configs (
148+ result = self . gather_benchmark_configs .generate_compatible_configs (
81149 model_name , target_os
82150 )
83151 expected = ["xnnpack_q8" , "qnn_q8" ]
84152 self .assertEqual (result , expected )
85153
86154 target_os = "ios"
87- result = gather_benchmark_configs .generate_compatible_configs (
155+ result = self . gather_benchmark_configs .generate_compatible_configs (
88156 model_name , target_os
89157 )
90158 expected = ["xnnpack_q8" , "coreml_fp16" , "mps" ]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
def test_generate_compatible_configs_unknown_model(self):
    """Check that an unrecognised model name yields an empty config list."""
    generate = self.gather_benchmark_configs.generate_compatible_configs
    self.assertEqual(generate("unknown_model", "ios"), [])
100168
def test_is_valid_huggingface_model_id_valid(self):
    """Check that an ``<org>/<model>`` style id is accepted as valid."""
    is_valid = self.gather_benchmark_configs.is_valid_huggingface_model_id
    self.assertTrue(is_valid("meta-llama/Llama-3.2-1B"))
106174
@patch("builtins.open", new_callable=mock_open)
@patch("os.getenv", return_value=None)
def test_set_output_no_github_env(self, mock_getenv, mock_file):
    """Check ``set_output`` prints a ``::set-output`` line when ``os.getenv`` gives ``None``.

    ``builtins.open`` is replaced by a mock for the duration of the test.
    """
    expected_line = "::set-output name=test_name::test_value"
    with patch("builtins.print") as fake_print:
        self.gather_benchmark_configs.set_output("test_name", "test_value")
        fake_print.assert_called_with(expected_line)
113181
114182 def test_device_pools_contains_all_devices (self ):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
120188 "google_pixel_8_pro" ,
121189 ]
122190 for device in expected_devices :
123- self .assertIn (device , gather_benchmark_configs .DEVICE_POOLS )
191+ self .assertIn (device , self . gather_benchmark_configs .DEVICE_POOLS )
124192
125193 def test_gather_benchmark_configs_cli (self ):
126194 args = {
0 commit comments