88"""
99
1010import argparse
11- from copy import deepcopy
1211import json
1312import logging
1413import os
14+ import re
15+ from copy import deepcopy
1516from dataclasses import dataclass
1617from datetime import datetime
17- from typing import Any , Dict , List , Optional , Tuple
18- import re
1918from enum import Enum
19+ from typing import Any , Dict , List
20+
2021import pandas as pd
2122import requests
2223
2324logging .basicConfig (level = logging .INFO )
2425
26+
2527class OutputType (Enum ):
2628 """
2729 Enumeration of supported output formats for benchmark data.
@@ -33,12 +35,14 @@ class OutputType(Enum):
        JSON: Export data to JSON files
        DF: Return data as pandas DataFrames
    """
+
    EXCEL = "excel"
    PRINT = "print"
    CSV = "csv"
    JSON = "json"
    DF = "df"

+
@dataclass
class BenchmarkQueryGroupDataParams:
    """
@@ -52,13 +56,15 @@ class BenchmarkQueryGroupDataParams:
        group_table_by_fields: Fields to group tables by
        group_row_by_fields: Fields to group rows by
    """
+
    repo: str
    benchmark_name: str
    start_time: str
    end_time: str
    group_table_by_fields: list
    group_row_by_fields: list

+
@dataclass
class MatchingGroupResult:
    """
@@ -68,14 +74,17 @@ class MatchingGroupResult:
        category: Category name (e.g., "private", "public")
        data: List of benchmark data for this category
    """
+
    category: str
    data: list

+
BASE_URLS = {
    "local": "http://localhost:3000",
    "prod": "https://hud.pytorch.org",
}

+
def validate_iso8601_no_ms(value: str):
    """
    Validate that a string is in ISO8601 format without milliseconds.
@@ -91,6 +100,7 @@ def validate_iso8601_no_ms(value: str):
91100 f"Invalid datetime format for '{ value } '. Expected: YYYY-MM-DDTHH:MM:SS"
92101 )
93102
103+
94104class ExecutorchBenchmarkFetcher :
95105 """
96106 Fetch and process benchmark data from HUD API for ExecutorchBenchmark.
@@ -145,14 +155,15 @@ def run(
        self,
        start_time: str,
        end_time: str,
-    ) -> Any:
+    ) -> None:
        data = self._fetch_execu_torch_data(start_time, end_time)
        if data is None:
            logging.warning("no data fetched from the HUD API")
            return None
+
        res = self._process(data)
-        self.data = res["data"]
-        private_list = res["private"]
+        self.data = res.get("data", [])
+        private_list = res.get("private", [])
        public_list = self._filter_public_result(private_list, res["public"])

        # reset group
@@ -163,7 +174,23 @@ def run(
        self.matching_groups["public"] = MatchingGroupResult(
            category="public", data=public_list
        )
-        return self.data
+
+    def _filter_out_failure_only(
+        self, data_list: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Remove rows that only contain FAILURE_REPORT metrics.
183+ """
184+ ONLY = {"workflow_id" , "granularity_bucket" , "job_id" , "FAILURE_REPORT" }
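+        # Rows whose keys are a subset of these bookkeeping fields carry no real metrics and are dropped.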
+        for item in data_list:
+            filtered_rows = [
+                row
+                for row in item.get("rows", [])
+                # Keep row only if it has additional fields beyond ONLY
+                if not set(row.keys()).issubset(ONLY)
+            ]
+            item["rows"] = filtered_rows
+        return [item for item in data_list if item.get("rows")]

    def _filter_public_result(self, private_list, public_list):
        """
@@ -184,7 +211,10 @@ def _filter_public_result(self, private_list, public_list):
            set([item["table_name"] for item in private_list])
            & set([item["table_name"] for item in public_list])
        )
-        logging.info(f"common table name for both private and public {len(common)}")
+        logging.info(
+            f"Found {len(common)} table names present in both private and public; using them to filter public tables:"
+        )
+        logging.info(json.dumps(common, indent=1))
        filtered_public = [item for item in public_list if item["table_name"] in common]
        return filtered_public

@@ -253,9 +283,7 @@ def output_data(
        Returns:
            Benchmark results in the specified format
        """
-        logging.info(
-            f"Generating output with type: {[category for category in self.matching_groups.keys()]}"
-        )
+        logging.info(f"Generating output with type: {list(self.matching_groups.keys())}")
        o_type = self._to_output_type(output_type)
        if o_type == OutputType.PRINT:
            logging.info("\n ========= Generate print output ========= \n")
@@ -351,7 +379,10 @@ def to_df(self) -> Any:
        result = {}
        for item in self.matching_groups.values():
            result[item.category] = [
-                {"groupInfo": item["groupInfo"], "df": pd.DataFrame(item["rows"])}
+                {
+                    "groupInfo": item.get("groupInfo", {}),
+                    "df": pd.DataFrame(item.get("rows", [])),
+                }
                for item in item.data
            ]
        return result
@@ -472,7 +503,7 @@ def _process(self, input_data: List[Dict[str, Any]]):
        Process raw benchmark data.

        This method:
-        1. Normalizes string values into new field info
+        1. Cleans out data generated only by FAILURE_REPORT jobs
        2. Creates table_name from info
        3. Determines aws_type (public/private) based on info.device
        4. Sorts results by table_name
@@ -483,12 +514,7 @@ def _process(self, input_data: List[Dict[str, Any]]):
483514 """
484515 # filter data with arch equal exactly "",ios and android, this normally indicates it's job-level falure indicator
485516 logging .info (f"fetched { len (input_data )} data from HUD" )
-        data = [
-            item
-            for item in input_data
-            if (arch := item.get("groupInfo", {}).get("arch")) is not None
-            and arch.lower() not in ("ios", "android")
-        ]
+        data = self._clean_data(input_data)

        private = []
        public = []
@@ -520,6 +546,17 @@ def _process(self, input_data: List[Dict[str, Any]]):
        )
        return {"data": data, "private": private, "public": public}

+    def _clean_data(self, data_list):
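+        """
+        Remove entries whose arch is missing or is a bare "ios"/"android" job-level
+        failure marker, then drop rows that only contain FAILURE_REPORT metrics.
+        """
+        # First pass: keep only items with a concrete arch that is not a job-level ios/android marker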
+        removed_gen_arch = [
+            item
+            for item in data_list
+            if (arch := item.get("groupInfo", {}).get("arch")) is not None
+            and arch.lower() not in ("ios", "android")
+        ]
+
+        data = self._filter_out_failure_only(removed_gen_arch)
+        return data
+
    def _fetch_execu_torch_data(self, start_time, end_time):
        url = f"{self.base_url}/api/benchmark/group_data"
        params_object = BenchmarkQueryGroupDataParams(
@@ -569,9 +606,7 @@ def argparsers():
    parser.add_argument(
        "--env", choices=["local", "prod"], default="prod", help="Environment"
    )
-    parser.add_argument(
-        "--silent", action="store_true", help="Disable logging"
-    )
+    parser.add_argument("--silent", action="store_true", help="Disable logging")

    # Options for generate_data
    parser.add_argument(