Skip to content

Commit 04dbd97

Browse files
committed
final
Signed-off-by: Yang Wang <[email protected]>
1 parent 2878e96 commit 04dbd97

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

.ci/scripts/benchmark_tooling/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ python .ci/scripts/benchmark_tooling/analyze_benchmark_stability.py \
119119

120120
## Running Unit Tests
121121

122-
The benchmark tooling includes comprehensive unit tests to ensure functionality.
122+
The benchmark tooling includes unit tests to ensure functionality.
123123

124124
### Using pytest
125125

.ci/scripts/benchmark_tooling/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import json
22
import os
3+
from typing import Any, Dict, List
34

45
import pandas as pd
56

67

7-
def read_excel_with_json_header(path: str):
8+
def read_excel_with_json_header(path: str) -> List[Dict[str, Any]]:
89
# Read all sheets into a dict of DataFrames, without altering
910
all_sheets = pd.read_excel(path, sheet_name=None, header=None, engine="openpyxl")
1011

@@ -21,7 +22,7 @@ def read_excel_with_json_header(path: str):
2122
return results
2223

2324

24-
def read_all_csv_with_metadata(folder_path: str):
25+
def read_all_csv_with_metadata(folder_path: str) -> List[Dict[str, Any]]:
2526
results = [] # {filename: {"meta": dict, "df": DataFrame}}
2627
for fname in os.listdir(folder_path):
2728
if not fname.lower().endswith(".csv"):

.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass
1717
from datetime import datetime
1818
from enum import Enum
19-
from typing import Any, Dict, List
19+
from typing import Any, Dict, List, Optional, Union
2020

2121
import pandas as pd
2222
import requests
@@ -126,8 +126,8 @@ class ExecutorchBenchmarkFetcher:
126126

127127
def __init__(
128128
self,
129-
env="prod",
130-
disable_logging=False,
129+
env: str = "prod",
130+
disable_logging: bool = False,
131131
group_table_fields=None,
132132
group_row_fields=None,
133133
):
@@ -189,7 +189,13 @@ def _filter_out_failure_only(
189189
self, data_list: List[Dict[str, Any]]
190190
) -> List[Dict[str, Any]]:
191191
"""
192-
clean FAILURE_REPORT only metrics
192+
Clean data by removing rows that only contain FAILURE_REPORT metrics.
193+
194+
Args:
195+
data_list: List of benchmark data dictionaries
196+
197+
Returns:
198+
Filtered list with rows containing only FAILURE_REPORT removed
193199
"""
194200
ONLY = {"workflow_id", "granularity_bucket", "job_id", "FAILURE_REPORT"}
195201
for item in data_list:
@@ -230,7 +236,13 @@ def _filter_public_result(self, private_list, public_list):
230236
filtered_public = [item for item in public_list if item["table_name"] in common]
231237
return filtered_public
232238

233-
def get_result(self):
239+
def get_result(self) -> Dict[str, List[Dict[str, Any]]]:
240+
"""
241+
Get a deep copy of the benchmark results.
242+
243+
Returns:
244+
Dictionary containing benchmark results grouped by category
245+
"""
234246
return deepcopy(self.to_dict())
235247

236248
def to_excel(self, output_dir: str = ".") -> None:
@@ -270,7 +282,7 @@ def _write_multi_sheet_excel(self, data_list, output_dir, file_name):
270282
worksheet.write_string(0, 0, json_str)
271283

272284
logging.info(
273-
f"Wrting excel sheet to file {file} with sheet name {sheet_name} for {entry["table_name"]}"
285+
f"Wrting excel sheet to file {file} with sheet name {sheet_name} for {entry['table_name']}"
274286
)
275287
# Write DataFrame starting at row 2 (index 1)
276288
df.to_excel(writer, sheet_name=sheet_name, startrow=1, index=False)
@@ -366,7 +378,7 @@ def generate_json_file(self, data, file_name, output_dir: str = "."):
366378
json.dump(data, f, indent=2)
367379
return path
368380

369-
def to_dict(self) -> Any:
381+
def to_dict(self) -> Dict[str, List[Dict[str, Any]]]:
370382
"""
371383
Convert benchmark results to a dictionary.
372384
@@ -378,15 +390,16 @@ def to_dict(self) -> Any:
378390
result[item.category] = item.data
379391
return result
380392

381-
def to_df(self) -> Any:
393+
def to_df(self) -> Dict[str, List[Dict[str, Union[Dict[str, Any], pd.DataFrame]]]]:
382394
"""
383395
Convert benchmark results to pandas DataFrames.
384396
385397
Creates a dictionary with categories as keys and lists of DataFrames as values.
386398
Each DataFrame represents one benchmark configuration.
387399
388400
Returns:
389-
Dictionary mapping categories to lists of DataFrames with metadata
401+
Dictionary mapping categories ['private','public'] to lists of DataFrames "df" with metadata "groupInfo".
402+
390403
"""
391404
result = {}
392405
for item in self.matching_groups.values():
@@ -423,7 +436,20 @@ def to_csv(self, output_dir: str = ".") -> None:
423436
path = os.path.join(output_dir, item.category)
424437
self._write_multiple_csv_files(item.data, path)
425438

426-
def _write_multiple_csv_files(self, data_list, output_dir, prefix=""):
439+
def _write_multiple_csv_files(
440+
self, data_list: List[Dict[str, Any]], output_dir: str, prefix: str = ""
441+
) -> None:
442+
"""
443+
Write multiple benchmark results to CSV files.
444+
445+
Creates a CSV file for each benchmark configuration, with metadata
446+
as a JSON string in the first row and data in subsequent rows.
447+
448+
Args:
449+
data_list: List of benchmark result dictionaries
450+
output_dir: Directory to save CSV files
451+
prefix: Optional prefix for CSV filenames
452+
"""
427453
os.makedirs(output_dir, exist_ok=True)
428454
for idx, entry in enumerate(data_list):
429455
filename = f"{prefix}_table{idx+1}.csv" if prefix else f"table{idx+1}.csv"
@@ -506,7 +532,9 @@ def _generate_matching_name(self, group_info: dict, fields: list[str]) -> str:
506532
# name = name +'(private)'
507533
return name
508534

509-
def _process(self, input_data: List[Dict[str, Any]], filters: BenchmarkFilters):
535+
def _process(
536+
self, input_data: List[Dict[str, Any]], filters: BenchmarkFilters
537+
) -> Dict[str, Any]:
510538
"""
511539
Process raw benchmark data.
512540
@@ -578,7 +606,9 @@ def _clean_data(self, data_list):
578606
data = self._filter_out_failure_only(removed_gen_arch)
579607
return data
580608

581-
def _fetch_execu_torch_data(self, start_time, end_time):
609+
def _fetch_execu_torch_data(
610+
self, start_time: str, end_time: str
611+
) -> Optional[List[Dict[str, Any]]]:
582612
url = f"{self.base_url}/api/benchmark/group_data"
583613
params_object = BenchmarkQueryGroupDataParams(
584614
repo="pytorch/executorch",
@@ -611,7 +641,19 @@ def normalize_string(self, s: str) -> str:
611641
s = s.replace(")-", ")").replace("-)", ")")
612642
return s
613643

614-
def filter_results(self, data: List, filters: BenchmarkFilters):
644+
def filter_results(self, data: List[Dict[str, Any]], filters: BenchmarkFilters) -> List[Dict[str, Any]]:
645+
"""
646+
Filter benchmark results based on specified criteria.
647+
648+
Applies OR logic for filtering - results match if they match any of the specified filters.
649+
650+
Args:
651+
data: List of benchmark data dictionaries
652+
filters: BenchmarkFilters object containing filter criteria
653+
654+
Returns:
655+
Filtered list of benchmark data dictionaries
656+
"""
615657
backends = filters.backends
616658
devices = filters.devices
617659
models = filters.models

0 commit comments

Comments (0)