 import importlib.util
 import os
+import re
 import subprocess
 import sys
 import unittest
 from unittest.mock import mock_open, patch
 
 import pytest
 
-# Dynamically import the script
-script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
-spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
-gather_benchmark_configs = importlib.util.module_from_spec(spec)
-spec.loader.exec_module(gather_benchmark_configs)
-
 
 @pytest.mark.skipif(
     sys.platform != "linux", reason="The script under test runs on Linux runners only"
 )
 class TestGatehrBenchmarkConfigs(unittest.TestCase):
 
+    @classmethod
+    def setUpClass(cls):
+        # Dynamically import the script
+        script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+        spec = importlib.util.spec_from_file_location(
+            "gather_benchmark_configs", script_path
+        )
+        cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(cls.gather_benchmark_configs)
+
     def test_extract_all_configs_android(self):
-        android_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        android_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
         )
         self.assertIn("xnnpack_q8", android_configs)
         self.assertIn("qnn_q8", android_configs)
         self.assertIn("llama3_spinquant", android_configs)
         self.assertIn("llama3_qlora", android_configs)
 
     def test_extract_all_configs_ios(self):
-        ios_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        ios_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
         )
 
         self.assertIn("xnnpack_q8", ios_configs)
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
         self.assertIn("llama3_spinquant", ios_configs)
         self.assertIn("llama3_qlora", ios_configs)
 
+    def test_skip_disabled_configs(self):
+        # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
+        with patch.dict(
+            self.gather_benchmark_configs.DISABLED_CONFIGS,
+            {
+                "mv3": [
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config1",
+                        github_issue="https://github.com/org/repo/issues/123",
+                    ),
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config2",
+                        github_issue="https://github.com/org/repo/issues/124",
+                    ),
+                ]
+            },
+        ), patch.dict(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS,
+            {
+                "ios": [
+                    "disabled_config1",
+                    "disabled_config2",
+                    "enabled_config1",
+                    "enabled_config2",
+                ]
+            },
+        ):
+            result = self.gather_benchmark_configs.generate_compatible_configs(
+                "mv3", target_os="ios"
+            )
+
+            # Assert that disabled configs are excluded
+            self.assertNotIn("disabled_config1", result)
+            self.assertNotIn("disabled_config2", result)
+            # Assert enabled configs are included
+            self.assertIn("enabled_config1", result)
+            self.assertIn("enabled_config2", result)
+
+    def test_disabled_configs_have_github_links(self):
+        github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")
+
+        for (
+            model_name,
+            disabled_configs,
+        ) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
+            for disabled in disabled_configs:
+                with self.subTest(model_name=model_name, config=disabled.config_name):
+                    # Assert that disabled is an instance of DisabledConfig
+                    self.assertIsInstance(
+                        disabled, self.gather_benchmark_configs.DisabledConfig
+                    )
+
+                    # Assert that github_issue is provided and matches the expected pattern
+                    self.assertTrue(
+                        disabled.github_issue
+                        and github_issue_regex.match(disabled.github_issue),
+                        f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
+                    )
+
     def test_generate_compatible_configs_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16", "llama3_coreml_ane"]
         self.assertEqual(result, expected)
 
         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16"]
         self.assertEqual(result, expected)
 
     def test_generate_compatible_configs_quantized_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_spinquant"]
         self.assertEqual(result, expected)
 
         model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_qlora"]
         self.assertEqual(result, expected)
 
     def test_generate_compatible_configs_non_genai_model(self):
         model_name = "mv2"
         target_os = "xplat"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8"]
         self.assertEqual(result, expected)
 
         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "qnn_q8"]
         self.assertEqual(result, expected)
 
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "coreml_fp16", "mps"]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
     def test_generate_compatible_configs_unknown_model(self):
         model_name = "unknown_model"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         self.assertEqual(result, [])
 
     def test_is_valid_huggingface_model_id_valid(self):
         valid_model = "meta-llama/Llama-3.2-1B"
         self.assertTrue(
-            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+            self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
         )
 
     @patch("builtins.open", new_callable=mock_open)
     @patch("os.getenv", return_value=None)
     def test_set_output_no_github_env(self, mock_getenv, mock_file):
         with patch("builtins.print") as mock_print:
-            gather_benchmark_configs.set_output("test_name", "test_value")
+            self.gather_benchmark_configs.set_output("test_name", "test_value")
             mock_print.assert_called_with("::set-output name=test_name::test_value")
 
     def test_device_pools_contains_all_devices(self):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
             "google_pixel_8_pro",
         ]
         for device in expected_devices:
-            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+            self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)
 
     def test_gather_benchmark_configs_cli(self):
         args = {