77and customizing data retrieval parameters.
88"""
99
10- from yaspin import yaspin
1110import argparse
1211import json
1312import logging
2120
2221import pandas as pd
2322import requests
23+ from yaspin import yaspin
2424
2525logging .basicConfig (level = logging .INFO )
2626
@@ -80,6 +80,13 @@ class MatchingGroupResult:
8080 data : list
8181
8282
@dataclass
class BenchmarkFilters:
    """
    Optional criteria used to narrow down the fetched benchmark results.

    Each field holds the values requested for one criterion. A field that
    is empty (or None, which is what the command-line parser supplies when
    the matching flag is omitted) leaves that criterion switched off.
    """

    # values compared against the `model` entry of an item's `info`
    models: list
    # values compared against the `backend` entry of an item's `info`
    backends: list
    # values looked up as substrings of the `device` entry of an item's `info`
    devices: list
88+
89+
8390BASE_URLS = {
8491 "local" : "http://localhost:3000" ,
8592 "prod" : "https://hud.pytorch.org" ,
@@ -156,19 +163,21 @@ def run(
156163 self ,
157164 start_time : str ,
158165 end_time : str ,
166+ filters : BenchmarkFilters ,
159167 ) -> None :
168+ # reset group & raw data for new run
169+ self .matching_groups = {}
170+ self .data = None
171+
160172 data = self ._fetch_execu_torch_data (start_time , end_time )
161173 if data is None :
162174 logging .warning ("no data fetched from the HUD API" )
163175 return None
164-
165- res = self ._process (data )
176+ res = self ._process (data , filters )
166177 self .data = res .get ("data" , [])
167178 private_list = res .get ("private" , [])
168179 public_list = self ._filter_public_result (private_list , res ["public" ])
169180
170- # reset group
171- self .matching_groups = {}
172181 self .matching_groups ["private" ] = MatchingGroupResult (
173182 category = "private" , data = private_list
174183 )
@@ -456,13 +465,18 @@ def print_all_groups_info(self) -> None:
456465 if not self .data or not self .matching_groups :
457466 logging .info ("No data found, please call get_data() first" )
458467 return
459- logging .info (f" all clean benchmark table info from HUD" )
468+ logging .info (
469+ "=========== Full list of table info from HUD API =============\n "
470+ " please use values in field `info` for filtering, "
471+ "while `groupInfo` holds the original benchmark metadata"
472+ )
460473 names = []
461474 for item in self .data :
462475 names .append (
463476 {
464477 "table_name" : item .get ("table_name" , "" ),
465- "groupInfo" : item .get ("groupInfo" , "" ),
478+ "groupInfo" : item .get ("groupInfo" , {}),
479+ "info" : item .get ("info" , {}),
466480 "counts" : len (item .get ("rows" , [])),
467481 }
468482 )
@@ -492,7 +506,7 @@ def _generate_matching_name(self, group_info: dict, fields: list[str]) -> str:
492506 # name = name +'(private)'
493507 return name
494508
495- def _process (self , input_data : List [Dict [str , Any ]]):
509+ def _process (self , input_data : List [Dict [str , Any ]], filters : BenchmarkFilters ):
496510 """
497511 Process raw benchmark data.
498512
@@ -509,9 +523,9 @@ def _process(self, input_data: List[Dict[str, Any]]):
509523 # filter data with arch equal exactly "",ios and android, this normally indicates it's a job-level failure indicator
510524 logging .info (f"fetched { len (input_data )} data from HUD" )
511525 data = self ._clean_data (input_data )
512-
513526 private = []
514527 public = []
528+
515529 for item in data :
516530 # normalized string values groupInfo to info
517531 item ["info" ] = {
@@ -528,17 +542,30 @@ def _process(self, input_data: List[Dict[str, Any]]):
528542 # Mark aws_type: private or public
529543 if group .get ("device" , "" ).find ("private" ) != - 1 :
530544 item ["info" ]["aws_type" ] = "private"
531- private .append (item )
532545 else :
533546 item ["info" ]["aws_type" ] = "public"
534547 public .append (item )
535- data .sort (key = lambda x : x ["table_name" ])
536- private .sort (key = lambda x : x ["table_name" ])
537- public .sort (key = lambda x : x ["table_name" ])
548+ raw_data = deepcopy (data )
549+
550+ # applies customized filters if any
551+ data = self .filter_results (data , filters )
552+ # generate private and public results
553+ private = sorted (
554+ (
555+ item
556+ for item in data
557+ if item .get ("info" , {}).get ("aws_type" ) == "private"
558+ ),
559+ key = lambda x : x ["table_name" ],
560+ )
561+ public = sorted (
562+ (item for item in data if item .get ("info" , {}).get ("aws_type" ) == "public" ),
563+ key = lambda x : x ["table_name" ],
564+ )
538565 logging .info (
539566 f"fetched clean data { len (data )} , private:{ len (private )} , public:{ len (public )} "
540567 )
541- return {"data" : data , "private" : private , "public" : public }
568+ return {"data" : raw_data , "private" : private , "public" : public }
542569
543570 def _clean_data (self , data_list ):
544571 removed_gen_arch = [
@@ -575,6 +602,7 @@ def _fetch_execu_torch_data(self, start_time, end_time):
575602
576603 def normalize_string (self , s : str ) -> str :
577604 s = s .lower ().strip ()
605+ s = s .replace ("+" ,"plus" )
578606 s = s .replace ("_" , "-" )
579607 s = s .replace (" " , "-" )
580608 s = re .sub (r"[^\w\-\.\(\)]" , "-" , s )
@@ -583,6 +611,37 @@ def normalize_string(self, s: str) -> str:
583611 s = s .replace (")-" , ")" ).replace ("-)" , ")" )
584612 return s
585613
def filter_results(self, data: List, filters: BenchmarkFilters):
    """
    Keep only the benchmark items that satisfy the requested filters.

    Every item's ``info`` dict is checked against the three filter lists.
    The values inside one list are alternatives; an item has to pass every
    list that is non-empty:
      - backends: ``info["backend"]`` must be one of the values
      - devices:  at least one value must be a substring of ``info["device"]``
      - models:   ``info["model"]`` must be one of the values

    Args:
        data: benchmark items (dicts) that carry an ``info`` dict.
        filters: requested backends / devices / models; an empty or None
            list disables the corresponding check.

    Returns:
        ``data`` itself when no filter is set, otherwise a new list holding
        the matching items in their original order.
    """
    wanted_backends = filters.backends
    wanted_devices = filters.devices
    wanted_models = filters.models

    # nothing requested: hand the input back untouched
    if not (wanted_backends or wanted_devices or wanted_models):
        return data

    logging.info(
        f"applies OR filter: backends {wanted_backends} , devices:{wanted_devices} ,models:{wanted_models} "
    )

    def is_selected(entry) -> bool:
        # checks run in the order backend -> device -> model and stop at
        # the first one that rejects the entry
        info = entry.get("info", {})
        return (
            (not wanted_backends or info.get("backend") in wanted_backends)
            and (
                not wanted_devices
                or any(name in info.get("device", "") for name in wanted_devices)
            )
            and (not wanted_models or info.get("model", "") in wanted_models)
        )

    total_before = len(data)
    selected = [entry for entry in data if is_selected(entry)]
    total_after = len(selected)
    logging.info(
        f"applied customized filter before: {total_before} , after: {total_after} "
    )

    # point the user at the right place to look up valid filter values
    if not selected:
        logging.info(
            "it seems like there is no result matches the filter values"
            ", please run script --no-silent again, and search for values in field"
            " 'info' for right format"
        )
    return selected
644+
586645
587646def argparsers ():
588647 parser = argparse .ArgumentParser (description = "Benchmark Analysis Runner" )
@@ -622,7 +681,17 @@ def argparsers():
622681 parser .add_argument (
623682 "--outputDir" , default = "." , help = "Output directory, default is ."
624683 )
625-
684+ parser .add_argument (
685+ "--backends" ,
686+ nargs = "+" ,
687+ help = "Filter results by one or more backend full name(e.g. --backend qlora mv3) (OR logic)" ,
688+ )
689+ parser .add_argument (
690+ "--devices" ,
691+ nargs = "+" ,
692+ help = "Filter results by device names (e.g. --devices samsung-galaxy-s22-5g)(OR logic)" ,
693+ )
694+ parser .add_argument ("--models" , nargs = "+" , help = "Filter by models (OR logic)" )
626695 return parser .parse_args ()
627696
628697
@@ -632,6 +701,9 @@ def argparsers():
632701 result = fetcher .run (
633702 args .startTime ,
634703 args .endTime ,
704+ filters = BenchmarkFilters (
705+ models = args .models , backends = args .backends , devices = args .devices
706+ ),
635707 )
636708 if not args .silent :
637709 fetcher .print_all_groups_info ()
0 commit comments