Skip to content

Commit ab6e6cf

Browse files
committed
final
Signed-off-by: Yang Wang <[email protected]>
1 parent 4aced24 commit ab6e6cf

File tree

1 file changed

+188
-20
lines changed

1 file changed

+188
-20
lines changed

.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py

Lines changed: 188 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
from dataclasses import dataclass
66
from datetime import datetime
7+
from sre_compile import REPEAT_ONE
78
from typing import Any, Dict, List, Optional, Tuple
89

910
import pandas as pd
@@ -139,6 +140,187 @@ def argparser():
139140
return parser.parse_args()
140141

141142

143+
class BenchmarkFetcher:
    """
    Fetch grouped benchmark data from the HUD API and post-process it.

    The fetcher queries ``/api/benchmark/group_data`` for a repo/benchmark
    pair, normalizes the returned group metadata, derives a table name for
    every group, and can expose the result as pandas DataFrames or CSV files.
    """

    def __init__(
        self,
        env: str = "prod",
        repo: str = "pytorch/pytorch",
        benchmark_name: str = "",
        disable_logging: bool = False,
        group_table_fields: Optional[List[str]] = None,
        group_row_fields: Optional[List[str]] = None,
        processor_funcs: Optional[List[Any]] = None,
    ):
        """
        Initialize the BenchmarkFetcher.

        Args:
            env: Environment to use ("local" or "prod")
            repo: Repository sent as the ``repo`` query parameter
            benchmark_name: Benchmark sent as the ``benchmark_name`` query parameter
            disable_logging: Whether to suppress log output (stored on the
                instance; the methods of this class do not consult it yet)
            group_table_fields: Custom fields to group tables by (defaults to device, backend, arch, model)
            group_row_fields: Custom fields to group rows by (defaults to workflow_id, job_id, granularity_bucket)
            processor_funcs: Optional list stored on the instance (defaults to
                an empty list; not invoked by any method of this class)

        Raises:
            KeyError: If `env` is not one of the known environments
        """
        self.env = env
        # Must run after self.env is set: the lookup is keyed on it.
        self.base_url = self._get_base_url()
        self.query_group_table_by_fields = (
            group_table_fields
            if group_table_fields
            else ["device", "backend", "arch", "model"]
        )
        self.query_group_row_by_fields = (
            group_row_fields
            if group_row_fields
            else ["workflow_id", "job_id", "granularity_bucket"]
        )
        self.data = None
        self.disable_logging = disable_logging
        self.processor_funcs = processor_funcs if processor_funcs else []

        self.repo = repo
        self.benchmark_name = benchmark_name

    def _get_base_url(self) -> str:
        """
        Get the base URL for API requests based on environment.

        Returns:
            Base URL string for the configured environment

        Raises:
            KeyError: If `self.env` is neither "local" nor "prod"
        """
        base_urls = {
            "local": "http://localhost:3000",
            "prod": "https://hud.pytorch.org",
        }
        return base_urls[self.env]

    def _fetch_data(
        self, start_time: str, end_time: str
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Fetch and process benchmark data for the specified time range.

        The processed result is also cached on `self.data`.

        Args:
            start_time: ISO8601 formatted start time
            end_time: ISO8601 formatted end time

        Returns:
            Processed benchmark data or None if fetch failed
        """
        # Delegate to the HTTP helper. Calling `_fetch_data` here (as the
        # previous revision did) recursed until RecursionError.
        data = self._call_api(start_time, end_time)
        if data is None:
            return None
        self.data = self._process(data)
        return self.data

    def to_df(self) -> Optional[List[Dict[str, Any]]]:
        """
        Convert the fetched tables into pandas DataFrames.

        Returns:
            One ``{"groupInfo": ..., "df": DataFrame}`` dict per fetched table,
            or None when nothing has been fetched (or the result was empty)
        """
        if not self.data:
            return None
        return [
            {"groupInfo": item["groupInfo"], "df": pd.DataFrame(item["rows"])}
            for item in self.data
        ]

    def _process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Process raw benchmark data.

        This method:
        1. Normalizes string values in groupInfo (non-string values are dropped)
        2. Creates table_name from group info components
        3. Determines aws_type (public/private) based on device name
        4. Sorts results by table_name

        The items and the list are modified in place.

        Args:
            data: Raw benchmark data from API

        Returns:
            Processed benchmark data (the same list object that was passed in)
        """
        for item in data:
            # `or {}` also covers a groupInfo key that is present but None.
            raw_group = item.get("groupInfo") or {}
            group = {
                k: self.normalize_string(v)
                for k, v in raw_group.items()
                if isinstance(v, str)
            }

            # Full name joined by the group key fields. Computed before
            # aws_type is added so the name only reflects fields from the API.
            item["table_name"] = self._generate_table_name(
                group, self.query_group_table_by_fields
            )

            # Mark aws_type: private or public
            group["aws_type"] = (
                "private" if "private" in group.get("device", "") else "public"
            )
            item["groupInfo"] = group

        data.sort(key=lambda x: x["table_name"])
        logging.info(f"fetched {len(data)} table views")
        return data

    def normalize_string(self, s, replace="_"):
        """Lower-case `s` and substitute `replace` for every space."""
        return s.lower().replace(" ", replace)

    def _generate_table_name(self, group_info: dict, fields: list[str]) -> str:
        """
        Build a table name from the non-empty `fields` present in `group_info`.

        Values are joined with "|" in the order given by `fields`, then
        normalized with `normalize_string`.
        """
        name = "|".join(
            group_info[k] for k in fields if k in group_info and group_info[k]
        )
        return self.normalize_string(name)

    def _call_api(self, start_time, end_time):
        """
        Query the HUD group_data endpoint for the configured repo and benchmark.

        Args:
            start_time: Start of the time range, forwarded as `start_time`
            end_time: End of the time range, forwarded as `end_time`

        Returns:
            Decoded JSON response on HTTP 200, otherwise None
        """
        url = f"{self.base_url}/api/benchmark/group_data"
        # Use the values given to the constructor; they were previously
        # ignored in favour of hard-coded "pytorch/executorch" / "ExecuTorch".
        params_object = BenchmarkQueryGroupDataParams(
            repo=self.repo,
            benchmark_name=self.benchmark_name,
            start_time=start_time,
            end_time=end_time,
            group_table_by_fields=self.query_group_table_by_fields,
            group_row_by_fields=self.query_group_row_by_fields,
        )
        # Unset (None) parameters are left out of the query string.
        params = {k: v for k, v in params_object.__dict__.items() if v is not None}
        response = requests.get(url, params=params)
        if response.status_code == 200:
            return response.json()
        logging.info(f"Failed to fetch benchmark data ({response.status_code})")
        logging.info(response.text)
        return None

    def _write_multiple_csv_files(
        self, data_list: List[Dict[str, Any]], output_dir: str, file_prefix=""
    ) -> None:
        """
        Write multiple benchmark results to separate CSV files.

        Each entry in `data_list` becomes its own CSV file, named after the
        entry's "short_name" (or "file<N>" when it has none).

        Args:
            data_list: List of benchmark result dictionaries
            output_dir: Directory to save the CSV files (created if missing)
            file_prefix: Optional string prepended to every file name
        """
        os.makedirs(output_dir, exist_ok=True)
        logging.info(
            f"\n ========= Generating multiple CSV files in {output_dir} ========= \n"
        )
        for idx, entry in enumerate(data_list):
            file_name = entry.get("short_name", f"file{idx+1}")

            if file_prefix:
                file_name = file_prefix + file_name
            # Cap the stem at 100 characters; the ".csv" suffix is added after.
            if len(file_name) > 100:
                logging.warning(
                    f"File name '{file_name}' is too long, truncating to 100 characters"
                )
                file_name = file_name[:100]
            file_path = os.path.join(output_dir, f"{file_name}.csv")

            rows = entry.get("rows", [])
            logging.info(f"Writing CSV: {file_path} with {len(rows)} rows")
            df = pd.DataFrame(rows)
            df.to_csv(file_path, index=False)
323+
142324
class ExecutorchBenchmarkFetcher:
143325
"""
144326
Fetch and process benchmark data from HUD API for ExecutorchBenchmark.
@@ -214,7 +396,7 @@ def run(
214396

215397
if not self.disable_logging:
216398
logging.info(
217-
f"\n ========= Search tables specific for matching keywords ========= \n"
399+
"\n ========= Search tables specific for matching keywords ========= \n"
218400
)
219401
self.results_private = self.find_target_tables(privateDeviceMatchings, True)
220402
self.results_public = self.find_target_tables(publicDeviceMatchings, False)
@@ -282,7 +464,7 @@ def _write_multi_sheet_excel(
282464
output_path: Path to save the Excel file
283465
"""
284466
logging.info(
285-
f"\n ========= Generate excel file with multiple sheets for {output_path}========= \n"
467+
"\n ========= Generate excel file with multiple sheets for {output_path}========= \n"
286468
)
287469
with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
288470
for idx, entry in enumerate(data_list):
@@ -403,29 +585,15 @@ def print_all_names(self) -> None:
403585
return
404586
logging.info("peeking table result:")
405587
logging.info(json.dumps(self.data[0], indent=2))
406-
public_ones = [
407-
item["table_name"]
408-
for item in self.data
409-
if item["groupInfo"]["aws_type"] == "public"
410-
]
411-
private_ones = [
412-
item["table_name"]
413-
for item in self.data
414-
if item["groupInfo"]["aws_type"] == "private"
415-
]
588+
names = [item["table_name"] for item in self.data]
416589
# Print all found benchmark table names
417590
logging.info(
418-
f"\n============List all benchmark result table names (Public and Private) below =================\n"
591+
"\n============List all benchmark result table names =================\n"
419592
)
420593
logging.info(
421-
f"\n============ public device benchmark results({len(public_ones)})=================\n"
422-
)
423-
for name in public_ones:
424-
logging.info(name)
425-
logging.info(
426-
f"\n======= private device benchmark results({len(private_ones)})=======\n"
594+
f"\n============ public device benchmark results({len(names)})=================\n"
427595
)
428-
for name in private_ones:
596+
for name in names:
429597
logging.info(name)
430598

431599
def _generate_table_name(self, group_info: dict, fields: list[str]) -> str:

0 commit comments

Comments
 (0)