|
4 | 4 | import os |
5 | 5 | from dataclasses import dataclass |
6 | 6 | from datetime import datetime |
| 7 | +from sre_compile import REPEAT_ONE |
7 | 8 | from typing import Any, Dict, List, Optional, Tuple |
8 | 9 |
|
9 | 10 | import pandas as pd |
@@ -139,6 +140,187 @@ def argparser(): |
139 | 140 | return parser.parse_args() |
140 | 141 |
|
141 | 142 |
|
class BenchmarkFetcher:
    """
    Fetch and process grouped benchmark data from the HUD API.

    Results come back as a list of "table" dicts; each table carries a
    ``groupInfo`` mapping, a derived ``table_name``, and a list of ``rows``.
    """

    def __init__(
        self,
        env="prod",
        repo="pytorch/pytorch",
        benchmark_name="",
        disable_logging=False,
        group_table_fields=None,
        group_row_fields=None,
        processor_funcs=None,
    ):
        """
        Initialize the BenchmarkFetcher.

        Args:
            env: Environment to use ("local" or "prod")
            repo: GitHub repository whose benchmarks are queried
            benchmark_name: Benchmark name passed to the query (empty = all)
            disable_logging: Whether to suppress log output
            group_table_fields: Custom fields to group tables by
                (defaults to device, backend, arch, model)
            group_row_fields: Custom fields to group rows by
                (defaults to workflow_id, job_id, granularity_bucket)
            processor_funcs: Optional list of post-processing callables
        """
        self.env = env
        self.base_url = self._get_base_url()
        # None sentinels instead of mutable defaults; fall back to the
        # standard groupings when the caller does not override them.
        self.query_group_table_by_fields = (
            group_table_fields
            if group_table_fields
            else ["device", "backend", "arch", "model"]
        )
        self.query_group_row_by_fields = (
            group_row_fields
            if group_row_fields
            else ["workflow_id", "job_id", "granularity_bucket"]
        )
        self.data = None
        self.disable_logging = disable_logging
        self.processor_funcs = processor_funcs if processor_funcs else []

        self.repo = repo
        self.benchmark_name = benchmark_name

    def _get_base_url(self) -> str:
        """
        Get the base URL for API requests based on environment.

        Returns:
            Base URL string for the configured environment

        Raises:
            KeyError: If ``self.env`` is not "local" or "prod".
        """
        base_urls = {
            "local": "http://localhost:3000",
            "prod": "https://hud.pytorch.org",
        }
        return base_urls[self.env]

    def _fetch_data(
        self, start_time: str, end_time: str
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Fetch and process benchmark data for the specified time range.

        Args:
            start_time: ISO8601 formatted start time
            end_time: ISO8601 formatted end time

        Returns:
            Processed benchmark data or None if fetch failed
        """
        # BUG FIX: the original called self._fetch_data here, recursing
        # infinitely. The raw fetch is done by _call_api.
        data = self._call_api(start_time, end_time)
        if data is None:
            return None
        self.data = self._process(data)
        return self.data

    def to_df(self) -> Any:
        """
        Convert fetched tables to pandas DataFrames.

        Returns:
            A list of {"groupInfo": ..., "df": DataFrame} dicts, or None
            if no data has been fetched yet.
        """
        if not self.data:
            return
        dfs = [
            {"groupInfo": item["groupInfo"], "df": pd.DataFrame(item["rows"])}
            for item in self.data
        ]
        return dfs

    def _process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Process raw benchmark data.

        This method:
        1. Normalizes string values in groupInfo (non-string values dropped)
        2. Creates table_name from group info components
        3. Determines aws_type (public/private) based on device name
        4. Sorts results by table_name

        Args:
            data: Raw benchmark data from API

        Returns:
            Processed benchmark data (same list, mutated in place)
        """
        for item in data:
            # Normalize string values in groupInfo; non-string/None entries
            # are intentionally dropped so table names stay well-formed.
            item["groupInfo"] = {
                k: self.normalize_string(v)
                for k, v in item.get("groupInfo", {}).items()
                if v is not None and isinstance(v, str)
            }
            group = item.get("groupInfo", {})
            name = self._generate_table_name(group, self.query_group_table_by_fields)

            # Add full name joined by the group key fields
            item["table_name"] = name

            # Mark aws_type: private or public (private runners embed the
            # word "private" in their device string).
            if group.get("device", "").find("private") != -1:
                item["groupInfo"]["aws_type"] = "private"
            else:
                item["groupInfo"]["aws_type"] = "public"

        data.sort(key=lambda x: x["table_name"])
        logging.info(f"fetched {len(data)} table views")
        return data

    def normalize_string(self, s: str, replace: str = "_") -> str:
        """Lowercase *s* and replace spaces with *replace*."""
        return s.lower().replace(" ", replace)

    def _generate_table_name(self, group_info: dict, fields: list[str]) -> str:
        """Join the non-empty *fields* of *group_info* with '|' and normalize."""
        name = "|".join(
            group_info[k] for k in fields if k in group_info and group_info[k]
        )
        return self.normalize_string(name)

    def _call_api(self, start_time, end_time):
        """
        Call the HUD group_data endpoint for the configured repo/benchmark.

        Returns:
            Decoded JSON payload on HTTP 200, otherwise None (failure is
            logged, not raised).
        """
        url = f"{self.base_url}/api/benchmark/group_data"
        # BUG FIX: the original hard-coded repo="pytorch/executorch" and
        # benchmark_name="ExecuTorch", ignoring the constructor arguments.
        params_object = BenchmarkQueryGroupDataParams(
            repo=self.repo,
            benchmark_name=self.benchmark_name,
            start_time=start_time,
            end_time=end_time,
            group_table_by_fields=self.query_group_table_by_fields,
            group_row_by_fields=self.query_group_row_by_fields,
        )
        # Drop unset params so they are not sent as "None" strings.
        params = {k: v for k, v in params_object.__dict__.items() if v is not None}
        response = requests.get(url, params=params)
        if response.status_code == 200:
            return response.json()
        else:
            logging.info(f"Failed to fetch benchmark data ({response.status_code})")
            logging.info(response.text)
            return None

    def _write_multiple_csv_files(
        self, data_list: List[Dict[str, Any]], output_dir: str, file_prefix=""
    ) -> None:
        """
        Write multiple benchmark results to separate CSV files.

        Each entry in `data_list` becomes its own CSV file.

        Args:
            data_list: List of benchmark result dictionaries
            output_dir: Directory to save the CSV files
            file_prefix: Optional prefix prepended to every file name
        """
        os.makedirs(output_dir, exist_ok=True)
        logging.info(
            f"\n ========= Generating multiple CSV files in {output_dir} ========= \n"
        )
        for idx, entry in enumerate(data_list):
            file_name = entry.get("short_name", f"file{idx+1}")

            if file_prefix:
                file_name = file_prefix + file_name
            # Guard against filesystem limits on very long names.
            if len(file_name) > 100:
                logging.warning(
                    f"File name '{file_name}' is too long, truncating to 100 characters"
                )
                file_name = file_name[:100]
            file_path = os.path.join(output_dir, f"{file_name}.csv")

            rows = entry.get("rows", [])
            logging.info(f"Writing CSV: {file_path} with {len(rows)} rows")
            df = pd.DataFrame(rows)
            df.to_csv(file_path, index=False)
| 323 | + |
142 | 324 | class ExecutorchBenchmarkFetcher: |
143 | 325 | """ |
144 | 326 | Fetch and process benchmark data from HUD API for ExecutorchBenchmark. |
@@ -214,7 +396,7 @@ def run( |
214 | 396 |
|
215 | 397 | if not self.disable_logging: |
216 | 398 | logging.info( |
217 | | - f"\n ========= Search tables specific for matching keywords ========= \n" |
| 399 | + "\n ========= Search tables specific for matching keywords ========= \n" |
218 | 400 | ) |
219 | 401 | self.results_private = self.find_target_tables(privateDeviceMatchings, True) |
220 | 402 | self.results_public = self.find_target_tables(publicDeviceMatchings, False) |
@@ -282,7 +464,7 @@ def _write_multi_sheet_excel( |
282 | 464 | output_path: Path to save the Excel file |
283 | 465 | """ |
284 | 466 | logging.info( |
285 | | - f"\n ========= Generate excel file with multiple sheets for {output_path}========= \n" |
| 467 | + "\n ========= Generate excel file with multiple sheets for {output_path}========= \n" |
286 | 468 | ) |
287 | 469 | with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer: |
288 | 470 | for idx, entry in enumerate(data_list): |
@@ -403,29 +585,15 @@ def print_all_names(self) -> None: |
403 | 585 | return |
404 | 586 | logging.info("peeking table result:") |
405 | 587 | logging.info(json.dumps(self.data[0], indent=2)) |
406 | | - public_ones = [ |
407 | | - item["table_name"] |
408 | | - for item in self.data |
409 | | - if item["groupInfo"]["aws_type"] == "public" |
410 | | - ] |
411 | | - private_ones = [ |
412 | | - item["table_name"] |
413 | | - for item in self.data |
414 | | - if item["groupInfo"]["aws_type"] == "private" |
415 | | - ] |
| 588 | + names = [item["table_name"] for item in self.data] |
416 | 589 | # Print all found benchmark table names |
417 | 590 | logging.info( |
418 | | - f"\n============List all benchmark result table names (Public and Private) below =================\n" |
| 591 | + "\n============List all benchmark result table names =================\n" |
419 | 592 | ) |
420 | 593 | logging.info( |
421 | | - f"\n============ public device benchmark results({len(public_ones)})=================\n" |
422 | | - ) |
423 | | - for name in public_ones: |
424 | | - logging.info(name) |
425 | | - logging.info( |
426 | | - f"\n======= private device benchmark results({len(private_ones)})=======\n" |
| 594 | + f"\n============ public device benchmark results({len(names)})=================\n" |
427 | 595 | ) |
428 | | - for name in private_ones: |
| 596 | + for name in names: |
429 | 597 | logging.info(name) |
430 | 598 |
|
431 | 599 | def _generate_table_name(self, group_info: dict, fields: list[str]) -> str: |
|
0 commit comments