@@ -192,20 +192,28 @@ def test_set_output_no_github_env(self, mock_getenv, mock_file):
192192
193193 def test_device_pools_contains_all_devices (self ):
194194 expected_devices = [
195- "apple_iphone_15" ,
196- "apple_iphone_15+ios_18 " ,
197- "samsung_galaxy_s22" ,
198- "samsung_galaxy_s24" ,
199- "google_pixel_8_pro " ,
195+ "apple_iphone_15+public " ,
196+ "apple_iphone_15+ios_18_public " ,
197+ "samsung_galaxy_s22+public " ,
198+ "samsung_galaxy_s24+ultra_private " ,
199+ "google_pixel_8+pro_public " ,
200200 ]
201201 for device in expected_devices :
202- self .assertIn (device , self .gather_benchmark_configs .DEVICE_POOLS )
202+ m = re .match (self .gather_benchmark_configs .DEVICE_POOLS_REGEX , device )
203+
204+ device_name = m .group ("device_name" )
205+ variant = m .group ("variant" )
206+
207+ self .assertIn (device_name , self .gather_benchmark_configs .DEVICE_POOLS )
208+ self .assertIn (
209+ variant , self .gather_benchmark_configs .DEVICE_POOLS [device_name ]
210+ )
203211
204212 def test_gather_benchmark_configs_cli (self ):
205213 args = {
206214 "models" : "mv2,dl3" ,
207215 "os" : "ios" ,
208- "devices" : "apple_iphone_15" ,
216+ "devices" : "apple_iphone_15+pro_private " ,
209217 "configs" : None ,
210218 }
211219
@@ -223,11 +231,29 @@ def test_gather_benchmark_configs_cli(self):
223231 self .assertIn ('"config": "xnnpack_q8"' , result .stdout )
224232 self .assertIn ('"config": "mps"' , result .stdout )
225233
226- def test_gather_benchmark_configs_cli_specified_configs (self ):
234+ def test_gather_benchmark_configs_cli_invalid_device (self ):
227235 args = {
228236 "models" : "mv2,dl3" ,
229237 "os" : "ios" ,
230238 "devices" : "apple_iphone_15" ,
239+ "configs" : None ,
240+ }
241+
242+ cmd = ["python" , ".ci/scripts/gather_benchmark_configs.py" ]
243+ for key , value in args .items ():
244+ if value is not None :
245+ cmd .append (f"--{ key } " )
246+ cmd .append (value )
247+
248+ result = subprocess .run (cmd , capture_output = True , text = True )
249+ self .assertEqual (result .returncode , 0 , f"Error: { result .stderr } " )
250+ self .assertIn ('{"include": []}' , result .stdout )
251+
252+ def test_gather_benchmark_configs_cli_specified_configs (self ):
253+ args = {
254+ "models" : "mv2,dl3" ,
255+ "os" : "ios" ,
256+ "devices" : "apple_iphone_15+private" ,
231257 "configs" : "coreml_fp16,xnnpack_q8" ,
232258 }
233259
@@ -249,7 +275,7 @@ def test_gather_benchmark_configs_cli_specified_configs_raise(self):
249275 args = {
250276 "models" : "mv2,dl3" ,
251277 "os" : "ios" ,
252- "devices" : "apple_iphone_15" ,
278+ "devices" : "apple_iphone_15+public " ,
253279 "configs" : "qnn_q8" ,
254280 }
255281
0 commit comments