@@ -139,188 +139,6 @@ def argparser():
 
     return parser.parse_args()
 
-
-class BenchmarkFetcher:
-    def __init__(
-        self,
-        env="prod",
-        repo="pytorch/pytorch",
-        benchmark_name="",
-        disable_logging=False,
-        group_table_fields=None,
-        group_row_fields=None,
-        processor_funcs=None,
-    ):
-        """
-        Initialize the BenchmarkFetcher.
-
-        Args:
-            env: Environment to use ("local" or "prod")
-            repo: Repository whose benchmark results are queried (e.g. "pytorch/pytorch")
-            benchmark_name: Benchmark name to query for
-            disable_logging: Whether to suppress log output
-            group_table_fields: Custom fields to group tables by (defaults to device, backend, arch, model)
-            group_row_fields: Custom fields to group rows by (defaults to workflow_id, job_id, granularity_bucket)
-        """
-        self.env = env
-        self.base_url = self._get_base_url()
-        self.query_group_table_by_fields = (
-            group_table_fields
-            if group_table_fields
-            else ["device", "backend", "arch", "model"]
-        )
-        self.query_group_row_by_fields = (
-            group_row_fields
-            if group_row_fields
-            else ["workflow_id", "job_id", "granularity_bucket"]
-        )
-        self.data = None
-        self.disable_logging = disable_logging
-        self.processor_funcs = processor_funcs if processor_funcs else []
-
-        self.repo = repo
-        self.benchmark_name = benchmark_name
-
-    def _get_base_url(self) -> str:
-        """
-        Get the base URL for API requests based on environment.
-
-        Returns:
-            Base URL string for the configured environment
-        """
-        base_urls = {
-            "local": "http://localhost:3000",
-            "prod": "https://hud.pytorch.org",
-        }
-        return base_urls[self.env]
-
-    def _fetch_data(
-        self, start_time: str, end_time: str
-    ) -> Optional[List[Dict[str, Any]]]:
-        """
-        Fetch and process benchmark data for the specified time range.
-
-        Args:
-            start_time: ISO8601 formatted start time
-            end_time: ISO8601 formatted end time
-
-        Returns:
-            Processed benchmark data, or None if the fetch failed
-        """
-        data = self._call_api(start_time, end_time)
-        if data is None:
-            return None
-        self.data = self._process(data)
-        return self.data
-
-    def to_df(self) -> Any:
-        # Wrap each fetched table view's rows in a pandas DataFrame, keeping its groupInfo.
-        if not self.data:
-            return
-        dfs = [
-            {"groupInfo": item["groupInfo"], "df": pd.DataFrame(item["rows"])}
-            for item in self.data
-        ]
-        return dfs
-
-    def _process(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """
-        Process raw benchmark data.
-
-        This method:
-        1. Normalizes string values in groupInfo
-        2. Creates table_name from group info components
-        3. Determines aws_type (public/private) based on device name
-        4. Sorts results by table_name
-
-        Args:
-            data: Raw benchmark data from API
-
-        Returns:
-            Processed benchmark data
-        """
-        for item in data:
-            # normalize string values in groupInfo
-            item["groupInfo"] = {
-                k: self.normalize_string(v)
-                for k, v in item.get("groupInfo", {}).items()
-                if v is not None and isinstance(v, str)
-            }
-            group = item.get("groupInfo", {})
-            name = self._generate_table_name(group, self.query_group_table_by_fields)
-
-            # Add full name joined by the group key fields
-            item["table_name"] = name
-
-            # Mark aws_type: private or public
-            if group.get("device", "").find("private") != -1:
-                item["groupInfo"]["aws_type"] = "private"
-            else:
-                item["groupInfo"]["aws_type"] = "public"
-
-        data.sort(key=lambda x: x["table_name"])
-        logging.info(f"fetched {len(data)} table views")
-        return data
-
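
To make the numbered steps in the `_process` docstring concrete, here is a small worked example of what one item looks like after processing; the field values are invented for illustration and assume the default group-by fields:

# Hypothetical input item (values invented for illustration).
item = {
    "groupInfo": {"device": "Apple iPhone 15 (private)", "backend": "coreml",
                  "arch": "iOS 17", "model": "llama3"},
    "rows": [],
}
# After BenchmarkFetcher._process([item]):
#   groupInfo values are lower-cased with spaces replaced by "_"
#   item["table_name"] == "apple_iphone_15_(private)|coreml|ios_17|llama3"
#   item["groupInfo"]["aws_type"] == "private"   (the device name contains "private")
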
-    def normalize_string(self, s, replace="_"):
-        return s.lower().replace(" ", replace)
-
-    def _generate_table_name(self, group_info: dict, fields: list[str]) -> str:
-        name = "|".join(
-            group_info[k] for k in fields if k in group_info and group_info[k]
-        )
-        return self.normalize_string(name)
-
-    def _call_api(self, start_time, end_time):
-        url = f"{self.base_url}/api/benchmark/group_data"
-        params_object = BenchmarkQueryGroupDataParams(
-            repo=self.repo,
-            benchmark_name=self.benchmark_name,
-            start_time=start_time,
-            end_time=end_time,
-            group_table_by_fields=self.query_group_table_by_fields,
-            group_row_by_fields=self.query_group_row_by_fields,
-        )
-        params = {k: v for k, v in params_object.__dict__.items() if v is not None}
-        response = requests.get(url, params=params)
-        if response.status_code == 200:
-            return response.json()
-        else:
-            logging.info(f"Failed to fetch benchmark data ({response.status_code})")
-            logging.info(response.text)
-            return None
-
-    def _write_multiple_csv_files(
-        self, data_list: List[Dict[str, Any]], output_dir: str, file_prefix=""
-    ) -> None:
-        """
-        Write multiple benchmark results to separate CSV files.
-
-        Each entry in `data_list` becomes its own CSV file.
-
-        Args:
-            data_list: List of benchmark result dictionaries
-            output_dir: Directory to save the CSV files
-            file_prefix: Optional prefix prepended to each generated file name
-        """
-        os.makedirs(output_dir, exist_ok=True)
-        logging.info(
-            f"\n========= Generating multiple CSV files in {output_dir} =========\n"
-        )
-        for idx, entry in enumerate(data_list):
-            file_name = entry.get("short_name", f"file{idx + 1}")
-
-            if file_prefix:
-                file_name = file_prefix + file_name
-            if len(file_name) > 100:
-                logging.warning(
-                    f"File name '{file_name}' is too long, truncating to 100 characters"
-                )
-                file_name = file_name[:100]
-            file_path = os.path.join(output_dir, f"{file_name}.csv")
-
-            rows = entry.get("rows", [])
-            logging.info(f"Writing CSV: {file_path} with {len(rows)} rows")
-            df = pd.DataFrame(rows)
-            df.to_csv(file_path, index=False)
-
-
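
For context, a minimal sketch of how the removed BenchmarkFetcher could be driven end to end, using the repo and benchmark name it previously hardcoded; the time window and output directory are illustrative, not taken from the change:

# Illustrative only: fetch one week of data and dump each table view to CSV.
fetcher = BenchmarkFetcher(env="prod", repo="pytorch/executorch", benchmark_name="ExecuTorch")
data = fetcher._fetch_data("2025-01-01T00:00:00", "2025-01-08T00:00:00")
if data is not None:
    fetcher._write_multiple_csv_files(data, output_dir="./benchmark_csv")
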
 class ExecutorchBenchmarkFetcher:
     """
     Fetch and process benchmark data from HUD API for ExecutorchBenchmark.