Skip to content

Commit 25769da

Browse files
committed
fix error test
Signed-off-by: Yang Wang <[email protected]>
1 parent d22af04 commit 25769da

File tree

3 files changed

+189
-30
lines changed

3 files changed

+189
-30
lines changed

.ci/scripts/benchmark_tooling/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ python3 .ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py \
4444

4545
##### Filtering Options:
4646
Notice, the filter needs full name matchings with correct format, to see all the options of the filter choices, please run the script with `--print-all-table-info`, and pay attention to section `Full list of table info from HUD API` with the field 'info', which contains normalized data we use to filter records from the original metadata 'groupInfo'.
47+
The filter block any record if it does not in any of the filter keywords.
4748

4849
- `--devices`: Filter by specific device names (e.g., "samsung-galaxy-s22-5g", "samsung-galaxy-s22plus-5g")
4950
- `--backends`: Filter by specific backend names (e.g., "qnn-q8" , ""llama3-spinquan)

.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -525,18 +525,7 @@ def _generate_table_name(
525525
for k in fields
526526
if k in group_info and group_info[k]
527527
)
528-
if "(private)" in name:
529-
name = name.replace("(private)", "")
530-
return name
531528

532-
def _generate_matching_name(self, group_info: dict, fields: list[str]) -> str:
533-
info = deepcopy(group_info)
534-
name = "_".join(
535-
self.normalize_string(info[k]) for k in fields if k in info and info[k]
536-
)
537-
if "(private)" in name:
538-
name = name.replace("(private)", "")
539-
# name = name +'(private)'
540529
return name
541530

542531
def _process(
@@ -562,23 +551,28 @@ def _process(
562551
public = []
563552

564553
for item in data:
565-
# normalized string values groupInfo to info
566-
item["info"] = {
567-
k: self.normalize_string(v)
568-
for k, v in item.get("groupInfo", {}).items()
569-
if v is not None and isinstance(v, str)
570-
}
554+
org_group = item.get("groupInfo", {})
555+
if "info" not in item:
556+
item["info"] = {}
557+
if org_group.get("device", "").find("private") != -1:
558+
item["info"]["aws_type"] = "private"
559+
else:
560+
item["info"]["aws_type"] = "public"
561+
public.append(item)
562+
563+
# Merge normalized groupInfo string values into item["info"]
564+
item["info"].update(
565+
{
566+
k: self.normalize_string(v)
567+
for k, v in item.get("groupInfo", {}).items()
568+
if v is not None and isinstance(v, str)
569+
}
570+
)
571571
group = item.get("info", {})
572572
# Add full name joined by the group key fields
573573
item["table_name"] = self._generate_table_name(
574574
group, self.query_group_table_by_fields
575575
)
576-
# Mark aws_type: private or public
577-
if group.get("device", "").find("private") != -1:
578-
item["info"]["aws_type"] = "private"
579-
else:
580-
item["info"]["aws_type"] = "public"
581-
public.append(item)
582576
raw_data = deepcopy(data)
583577

584578
# applies customized filters if any
@@ -646,6 +640,7 @@ def normalize_string(self, s: str) -> str:
646640
s = re.sub(r"-{2,}", "-", s)
647641
s = s.replace("-(", "(").replace("(-", "(")
648642
s = s.replace(")-", ")").replace("-)", ")")
643+
s = s.replace("(private)", "")
649644
return s
650645

651646
def filter_results(
@@ -678,7 +673,7 @@ def filter_results(
678673
info = item.get("info", {})
679674
if backends and info.get("backend") not in backends:
680675
continue
681-
if devices and not any(dev in info.get("device", "") for dev in devices):
676+
if devices and info.get("device", "") not in devices:
682677
continue
683678
if models and info.get("model", "") not in models:
684679
continue
@@ -688,7 +683,7 @@ def filter_results(
688683
if after_len == 0:
689684
logging.info(
690685
"it seems like there is no result matches the filter values"
691-
", please run script --no-silent again, and search for values in field"
686+
", please run script --list-all-table-info again, and search for values in field"
692687
" 'info' for right format"
693688
)
694689
return results
@@ -742,15 +737,17 @@ def argparsers():
742737
parser.add_argument(
743738
"--backends",
744739
nargs="+",
745-
help="Filter results by one or more backend full name(e.g. --backend qlora mv3) (OR logic)",
740+
help="Filter results by one or more backend full name(e.g. --backend qlora mv3) (OR logic within backends scope, AND logic with other filter type)",
746741
)
747742
parser.add_argument(
748743
"--devices",
749744
nargs="+",
750-
help="Filter results by one or more device names (e.g. --devices samsung-galaxy-s22-5g)(OR logic)",
745+
help="Filter results by one or more device names (e.g. --devices samsung-galaxy-s22-5g)(OR logic within devices, AND logic with other filter type)",
751746
)
752747
parser.add_argument(
753-
"--models", nargs="+", help="Filter by one or more models (OR logic)"
748+
"--models",
749+
nargs="+",
750+
help="Filter by one or more models (OR logic withn models scope, AND logic with other filter type)",
754751
)
755752
return parser.parse_args()
756753

.ci/scripts/tests/test_get_benchmark_analysis_data.py

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def test_normalize_string(self):
206206
("test_string", "test-string"),
207207
("test string", "test-string"),
208208
("test--string", "test-string"),
209-
("test (private)", "test(private)"),
209+
("test (private)", "test"),
210210
("test@#$%^&*", "test-"),
211211
]
212212

@@ -335,6 +335,97 @@ def test_filter_public_result(self):
335335
result = self.fetcher._filter_public_result(private_list, public_list)
336336
self.assertEqual(result, expected)
337337

338+
def test_filter_results(self):
339+
"""Test filter_results method with various filter combinations."""
340+
# Create test data
341+
test_data = [
342+
{
343+
"info": {
344+
"model": "llama3",
345+
"backend": "qlora",
346+
"device": "iphone-15-pro-max",
347+
"arch": "ios-17",
348+
},
349+
"rows": [{"metric_1": 1.0}],
350+
},
351+
{
352+
"info": {
353+
"model": "llama3",
354+
"backend": "spinquant",
355+
"device": "iphone-15-pro-max",
356+
"arch": "ios-17",
357+
},
358+
"rows": [{"metric_1": 2.0}],
359+
},
360+
{
361+
"info": {
362+
"model": "mv3",
363+
"backend": "xnnpack-q8",
364+
"device": "samsung-galaxy-s22-5g",
365+
"arch": "android-13",
366+
},
367+
"rows": [{"metric_1": 3.0}],
368+
},
369+
{
370+
"info": {
371+
"model": "mv3",
372+
"backend": "qnn-q8",
373+
"device": "samsung-galaxy-s22-5g",
374+
"arch": "android-13",
375+
},
376+
"rows": [{"metric_1": 4.0}],
377+
},
378+
]
379+
380+
# Test with no filters
381+
empty_filters = self.module.BenchmarkFilters(
382+
models=None, backends=None, devices=None
383+
)
384+
result = self.fetcher.filter_results(test_data, empty_filters)
385+
self.assertEqual(result, test_data)
386+
387+
# Test with model filter
388+
model_filters = self.module.BenchmarkFilters(
389+
models=["llama3"], backends=None, devices=None
390+
)
391+
result = self.fetcher.filter_results(test_data, model_filters)
392+
self.assertEqual(len(result), 2)
393+
self.assertTrue(all(item["info"]["model"] == "llama3" for item in result))
394+
395+
# Test with backend filter
396+
backend_filters = self.module.BenchmarkFilters(
397+
models=None, backends=["qlora", "qnn-q8"], devices=None
398+
)
399+
result = self.fetcher.filter_results(test_data, backend_filters)
400+
self.assertEqual(len(result), 2)
401+
self.assertTrue(
402+
all(item["info"]["backend"] in ["qlora", "qnn-q8"] for item in result)
403+
)
404+
405+
# Test with device filter
406+
device_filters = self.module.BenchmarkFilters(
407+
models=None, backends=None, devices=["samsung-galaxy-s22-5g"]
408+
)
409+
result = self.fetcher.filter_results(test_data, device_filters)
410+
self.assertEqual(len(result), 2)
411+
self.assertTrue(
412+
all("samsung-galaxy-s22-5g" in item["info"]["device"] for item in result)
413+
)
414+
415+
# Test with combined filters (And logic fails)
416+
combined_filters = self.module.BenchmarkFilters(
417+
models=["llama3"], backends=["xnnpack-q8"], devices=None
418+
)
419+
result = self.fetcher.filter_results(test_data, combined_filters)
420+
self.assertEqual(len(result), 0)
421+
422+
# Test with combined filters (And logic success)
423+
combined_filters = self.module.BenchmarkFilters(
424+
models=["llama3"], backends=None, devices=["iphone-15-pro-max"]
425+
)
426+
result = self.fetcher.filter_results(test_data, combined_filters)
427+
self.assertEqual(len(result), 2)
428+
338429
@patch(
339430
"get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
340431
)
@@ -442,7 +533,7 @@ def test_run_with_failure_report(self, mock_fetch):
442533
"arch": "ios-17.4.3",
443534
"aws_type": "private",
444535
"backend": "qlora",
445-
"device": "iphone-15-pro-max(private)",
536+
"device": "iphone-15-pro-max",
446537
"model": "llama3",
447538
},
448539
"rows": [
@@ -475,6 +566,76 @@ def test_run_no_data(self, mock_fetch):
475566
self.assertEqual(self.fetcher.matching_groups, {})
476567
mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00")
477568

569+
@patch(
570+
"get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data"
571+
)
572+
def test_run_with_filters(self, mock_fetch):
573+
"""Test run method with filters."""
574+
# Setup mock data
575+
mock_data = [
576+
{
577+
"groupInfo": {
578+
"model": "llama3",
579+
"backend": "qlora",
580+
"device": "Iphone 15 pro max (private)",
581+
"arch": "ios_17",
582+
},
583+
"rows": [{"metric_1": 1.0}],
584+
},
585+
{
586+
"groupInfo": {
587+
"model": "mv3",
588+
"backend": "xnnpack_q8",
589+
"device": "s22_5g (private)",
590+
"arch": "android_13",
591+
},
592+
"rows": [{"metric_1": 2.0}],
593+
},
594+
{
595+
"groupInfo": {
596+
"model": "mv3",
597+
"backend": "xnnpack_q8",
598+
"device": "s22_5g",
599+
"arch": "android_13",
600+
},
601+
"rows": [{"metric_1": 3.0}],
602+
},
603+
]
604+
mock_fetch.return_value = mock_data
605+
606+
# Create filters for llama3 model only
607+
filters = self.module.BenchmarkFilters(
608+
models=["llama3"], backends=None, devices=None
609+
)
610+
# Run the method with filters
611+
self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00", filters)
612+
result = self.fetcher.get_result()
613+
print("result1", result)
614+
615+
# Verify results - should only have llama3 in private results
616+
self.assertEqual(len(result["private"]), 1)
617+
self.assertEqual(result["private"][0]["info"]["model"], "llama3")
618+
619+
# Public results should be empty since there's no matching table_name
620+
self.assertEqual(result["public"], [])
621+
622+
# Test with backend filter
623+
filters = self.module.BenchmarkFilters(
624+
models=None, backends=["xnnpack-q8"], devices=None
625+
)
626+
self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00", filters)
627+
result = self.fetcher.get_result()
628+
629+
print("result", result)
630+
631+
# Verify results - should only have xnnpack-q8 in private results
632+
self.assertEqual(len(result["private"]), 1)
633+
self.assertEqual(result["private"][0]["info"]["backend"], "xnnpack-q8")
634+
635+
# Public results should have the matching xnnpack-q8 entry
636+
self.assertEqual(len(result["public"]), 1)
637+
self.assertEqual(result["public"][0]["info"]["backend"], "xnnpack-q8")
638+
478639
def test_to_dict(self):
479640
"""Test to_dict method."""
480641
# Setup test data

0 commit comments

Comments
 (0)