Skip to content

Commit d22af04

Browse files
committed
fix error test
Signed-off-by: Yang Wang <[email protected]>
1 parent 95b30a4 commit d22af04

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class MatchingGroupResult:
7272
Container for benchmark results grouped by category.
7373
7474
Attributes:
75-
category: Category name (e.g., "private", "public")
75+
category: Category name (e.g., 'private', 'public')
7676
data: List of benchmark data for this category
7777
"""
7878

@@ -135,7 +135,7 @@ def __init__(
135135
Initialize the ExecutorchBenchmarkFetcher.
136136
137137
Args:
138-
env: Environment to use ("local" or "prod")
138+
env: Environment to use ('local' or 'prod')
139139
disable_logging: Whether to suppress log output
140140
group_table_fields: Custom fields to group tables by (defaults to device, backend, arch, model)
141141
group_row_fields: Custom fields to group rows by (defaults to workflow_id, job_id, granularity_bucket)
@@ -163,7 +163,7 @@ def run(
163163
self,
164164
start_time: str,
165165
end_time: str,
166-
filters: BenchmarkFilters,
166+
filters: Optional[BenchmarkFilters] = None,
167167
) -> None:
168168
# reset group & raw data for new run
169169
self.matching_groups = {}
@@ -398,7 +398,7 @@ def to_df(self) -> Dict[str, List[Dict[str, Union[Dict[str, Any], pd.DataFrame]]
398398
Each DataFrame represents one benchmark configuration.
399399
400400
Returns:
401-
Dictionary mapping categories ['private','public'] to lists of DataFrames "df" with metadata "groupInfo".
401+
Dictionary mapping categories ['private','public'] to lists of DataFrames "df" with metadata 'groupInfo'.
402402
403403
"""
404404
result = {}
@@ -476,11 +476,7 @@ def _get_base_url(self) -> str:
476476
Returns:
477477
Base URL string for the configured environment
478478
"""
479-
base_urls = {
480-
"local": "http://localhost:3000",
481-
"prod": "https://hud.pytorch.org",
482-
}
483-
return base_urls[self.env]
479+
return BASE_URLS[self.env]
484480

485481
def print_all_table_info(self) -> None:
486482
"""
@@ -509,7 +505,9 @@ def print_all_table_info(self) -> None:
509505
for name in names:
510506
logging.info(json.dumps(name, indent=2))
511507

512-
def _generate_table_name(self, group_info: Dict[str, Any], fields: List[str]) -> str:
508+
def _generate_table_name(
509+
self, group_info: Dict[str, Any], fields: List[str]
510+
) -> str:
513511
"""
514512
Generate a table name from group info fields.
515513
@@ -531,7 +529,6 @@ def _generate_table_name(self, group_info: Dict[str, Any], fields: List[str]) ->
531529
name = name.replace("(private)", "")
532530
return name
533531

534-
535532
def _generate_matching_name(self, group_info: dict, fields: list[str]) -> str:
536533
info = deepcopy(group_info)
537534
name = "_".join(
@@ -543,7 +540,7 @@ def _generate_matching_name(self, group_info: dict, fields: list[str]) -> str:
543540
return name
544541

545542
def _process(
546-
self, input_data: List[Dict[str, Any]], filters: BenchmarkFilters
543+
self, input_data: List[Dict[str, Any]], filters: Optional[BenchmarkFilters]
547544
) -> Dict[str, Any]:
548545
"""
549546
Process raw benchmark data.
@@ -576,7 +573,6 @@ def _process(
576573
item["table_name"] = self._generate_table_name(
577574
group, self.query_group_table_by_fields
578575
)
579-
580576
# Mark aws_type: private or public
581577
if group.get("device", "").find("private") != -1:
582578
item["info"]["aws_type"] = "private"
@@ -586,7 +582,8 @@ def _process(
586582
raw_data = deepcopy(data)
587583

588584
# applies customized filters if any
589-
data = self.filter_results(data, filters)
585+
if filters:
586+
data = self.filter_results(data, filters)
590587
# generate private and public results
591588
private = sorted(
592589
(
@@ -651,7 +648,9 @@ def normalize_string(self, s: str) -> str:
651648
s = s.replace(")-", ")").replace("-)", ")")
652649
return s
653650

654-
def filter_results(self, data: List[Dict[str, Any]], filters: BenchmarkFilters) -> List[Dict[str, Any]]:
651+
def filter_results(
652+
self, data: List[Dict[str, Any]], filters: BenchmarkFilters
653+
) -> List[Dict[str, Any]]:
655654
"""
656655
Filter benchmark results based on specified criteria.
657656

0 commit comments

Comments
 (0)