Skip to content

Commit 2e76fc5

Browse files
committed
Fix unit test
Signed-off-by: Huy Do <[email protected]>
1 parent 196ca8e commit 2e76fc5

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

.ci/scripts/tests/test_gather_benchmark_configs.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,28 @@ def test_set_output_no_github_env(self, mock_getenv, mock_file):
192192

193193
def test_device_pools_contains_all_devices(self):
194194
expected_devices = [
195-
"apple_iphone_15",
196-
"apple_iphone_15+ios_18",
197-
"samsung_galaxy_s22",
198-
"samsung_galaxy_s24",
199-
"google_pixel_8_pro",
195+
"apple_iphone_15+public",
196+
"apple_iphone_15+ios_18_public",
197+
"samsung_galaxy_s22+public",
198+
"samsung_galaxy_s24+ultra_private",
199+
"google_pixel_8+pro_public",
200200
]
201201
for device in expected_devices:
202-
self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)
202+
m = re.match(self.gather_benchmark_configs.DEVICE_POOLS_REGEX, device)
203+
204+
device_name = m.group("device_name")
205+
variant = m.group("variant")
206+
207+
self.assertIn(device_name, self.gather_benchmark_configs.DEVICE_POOLS)
208+
self.assertIn(
209+
variant, self.gather_benchmark_configs.DEVICE_POOLS[device_name]
210+
)
203211

204212
def test_gather_benchmark_configs_cli(self):
205213
args = {
206214
"models": "mv2,dl3",
207215
"os": "ios",
208-
"devices": "apple_iphone_15",
216+
"devices": "apple_iphone_15+pro_private",
209217
"configs": None,
210218
}
211219

@@ -223,11 +231,29 @@ def test_gather_benchmark_configs_cli(self):
223231
self.assertIn('"config": "xnnpack_q8"', result.stdout)
224232
self.assertIn('"config": "mps"', result.stdout)
225233

226-
def test_gather_benchmark_configs_cli_specified_configs(self):
234+
def test_gather_benchmark_configs_cli_invalid_device(self):
227235
args = {
228236
"models": "mv2,dl3",
229237
"os": "ios",
230238
"devices": "apple_iphone_15",
239+
"configs": None,
240+
}
241+
242+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
243+
for key, value in args.items():
244+
if value is not None:
245+
cmd.append(f"--{key}")
246+
cmd.append(value)
247+
248+
result = subprocess.run(cmd, capture_output=True, text=True)
249+
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
250+
self.assertIn('{"include": []}', result.stdout)
251+
252+
def test_gather_benchmark_configs_cli_specified_configs(self):
253+
args = {
254+
"models": "mv2,dl3",
255+
"os": "ios",
256+
"devices": "apple_iphone_15+private",
231257
"configs": "coreml_fp16,xnnpack_q8",
232258
}
233259

@@ -249,7 +275,7 @@ def test_gather_benchmark_configs_cli_specified_configs_raise(self):
249275
args = {
250276
"models": "mv2,dl3",
251277
"os": "ios",
252-
"devices": "apple_iphone_15",
278+
"devices": "apple_iphone_15+public",
253279
"configs": "qnn_q8",
254280
}
255281

0 commit comments

Comments
 (0)