6767GPU_NAME_STRIP = ["NVIDIA GeForce " , "Tesla " , "AMD Radeon " ] # Strip prefixes for smaller tables.
6868MODEL_SUFFIX_REPLACE = {" - Small" : "_S" , " - Medium" : "_M" , " - Large" : "_L" }
6969
70- DESCRIPTION = """Creates tables from llama-bench data written to a JSONL file or SQLite database. Example usage (Linux):
70+ DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
7171
7272$ git checkout master
7373$ make clean && make llama-bench
9595)
9696parser .add_argument ("-c" , "--compare" , help = help_c )
9797help_i = (
98- "Input JSONL/SQLite file for comparing commits. "
98+ "Input JSONL/SQLite file or JSON/CSV files for comparing commits. "
99+ "Specify this argument multiple times for multiple files. "
99100 "Defaults to 'llama-bench.sqlite' in the current working directory. "
100101 "If no such file is found and there is exactly one .sqlite file in the current directory, "
101102 "that file is instead used as input."
102103)
103- parser .add_argument ("-i" , "--input" , help = help_i )
104+ parser .add_argument ("-i" , "--input" , action = "append" , help = help_i )
104105help_o = (
105106 "Output format for the table. "
106107 "Defaults to 'pipe' (GitHub compatible). "
135136 sys .exit (1 )
136137
# Resolve the list of input files. Explicit -i/--input values win; otherwise
# fall back to the conventional database name, and failing that to a lone
# *.sqlite file in the current working directory.
input_file = known_args.input
if not input_file:
    if os.path.exists("./llama-bench.sqlite"):
        input_file = ["llama-bench.sqlite"]
    else:
        sqlite_files = glob("*.sqlite")
        # Only guess when the choice is unambiguous.
        if len(sqlite_files) == 1:
            input_file = sqlite_files

# Nothing given and nothing discovered: explain the options and bail out.
if not input_file:
    logger.error("Cannot find a suitable input file, please provide one.\n")
    parser.print_help()
    sys.exit(1)
@@ -285,48 +286,59 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare
285286
286287
class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
    """llama-bench data read directly from an existing SQLite3 database file."""

    def __init__(self, data_file: str):
        """Connect to ``data_file`` and initialise the build list from it.

        Args:
            data_file: Path to an SQLite3 database. Callers are expected to
                have checked it with ``valid_format`` beforehand; no
                validation happens here.
        """
        super().__init__()

        # The previous revision of this class closed the connection set up by
        # the base initializer before replacing it; keep doing so instead of
        # leaving it to the garbage collector. The getattr guard keeps this
        # harmless should the base class not define one.
        previous = getattr(self, "connection", None)
        if previous is not None:
            previous.close()

        self.connection = sqlite3.connect(data_file)
        self.cursor = self.connection.cursor()
        self._builds_init()

    @staticmethod
    def valid_format(data_file: str) -> bool:
        """Return whether ``data_file`` is an existing, non-empty SQLite3 database.

        This is a best-effort probe: it never raises for a bad file and never
        modifies the file system.

        Args:
            data_file: Path of the candidate file.

        Returns:
            True if the file can be opened as SQLite3 and has a non-zero
            schema version, False otherwise.
        """
        # sqlite3.connect() creates a missing file as an empty database, which
        # would leave a stray file behind that later format probes then pick
        # up. Refuse anything that is not already a regular file.
        if not os.path.isfile(data_file):
            logger.debug(f'"{data_file}" is not a valid SQLite3 file.')
            return False

        connection = None
        try:
            connection = sqlite3.connect(data_file)
            cursor = connection.cursor()
            # A schema version of 0 means the database holds no schema at all.
            if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
                raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
        except sqlite3.DatabaseError as e:
            logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
            return False
        finally:
            # Always release the probe connection, even on unexpected errors.
            if connection is not None:
                connection.close()

        return True
310310
311311
class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
    """llama-bench data loaded from a JSONL file (one JSON object per line)."""

    def __init__(self, data_file: str):
        """Insert every line of ``data_file`` into the ``test`` table.

        Args:
            data_file: Path to a UTF-8 encoded JSONL file.

        Raises:
            RuntimeError: If a line lacks one of the keys required by
                ``_check_keys``; the message names the 1-based line number.
        """
        super().__init__()

        # Built once instead of once per input line.
        known_fields = set(DB_FIELDS)

        with open(data_file, "r", encoding="utf-8") as fp:
            for i, line in enumerate(fp):
                parsed = json.loads(line)

                # Drop keys that have no matching database column.
                for k in parsed.keys() - known_fields:
                    del parsed[k]

                if (missing_keys := self._check_keys(parsed.keys())):
                    raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")

                # Column names are restricted to DB_FIELDS by the filter
                # above; the values themselves are bound as parameters.
                self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))

        self._builds_init()

    @staticmethod
    def valid_format(data_file: str) -> bool:
        """Return whether ``data_file`` looks like a JSONL file.

        Only the first line is inspected. It must parse as a JSON *object*,
        because the loader treats every line as a mapping; a JSON array or
        scalar on the first line is therefore rejected. A file without any
        lines is accepted.

        Args:
            data_file: Path of the candidate file.

        Returns:
            True if the file could be opened and its first line (if any) is
            a JSON object, False otherwise.
        """
        try:
            with open(data_file, "r", encoding="utf-8") as fp:
                for line in fp:
                    if not isinstance(json.loads(line), dict):
                        raise ValueError("The first line is not a JSON object.")
                    break
        except Exception as e:
            # Deliberately broad: format sniffing must never abort the run.
            logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
            return False

        return True
330342
331343
332344class LlamaBenchDataJSON (LlamaBenchDataSQLite3 ):
@@ -348,6 +360,22 @@ def __init__(self, data_files: list[str]):
348360
349361 self ._builds_init ()
350362
363+ @staticmethod
364+ def valid_format (data_files : list [str ]) -> bool :
365+ if not data_files :
366+ return False
367+
368+ for data_file in data_files :
369+ try :
370+ with open (data_file , "r" , encoding = "utf-8" ) as fp :
371+ json .load (fp )
372+ continue
373+ except Exception as e :
374+ logger .debug (f'"{ data_file } " is not a valid JSON file.' , exc_info = e )
375+ return False
376+
377+ return True
378+
351379
352380class LlamaBenchDataCSV (LlamaBenchDataSQLite3 ):
353381 def __init__ (self , data_files : list [str ]):
@@ -368,8 +396,37 @@ def __init__(self, data_files: list[str]):
368396
369397 self ._builds_init ()
370398
399+ @staticmethod
400+ def valid_format (data_files : list [str ]) -> bool :
401+ if not data_files :
402+ return False
403+
404+ for data_file in data_files :
405+ try :
406+ with open (data_file , "r" , encoding = "utf-8" ) as fp :
407+ csv .DictReader (fp )
408+ continue
409+ except Exception as e :
410+ logger .debug (f'"{ data_file } " is not a valid CSV file.' , exc_info = e )
411+ return False
412+
413+ return True
414+
415+
# Pick a loader by probing the input format. SQLite and JSONL are single-file
# formats and are only considered when exactly one file was given. JSON and
# CSV accept any number of files, so they also serve as the fallback for a
# single file that is neither SQLite nor JSONL.
bench_data = None
if len(input_file) == 1:
    if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
        bench_data = LlamaBenchDataSQLite3File(input_file[0])
    elif LlamaBenchDataJSONL.valid_format(input_file[0]):
        bench_data = LlamaBenchDataJSONL(input_file[0])

if bench_data is None:
    if LlamaBenchDataJSON.valid_format(input_file):
        bench_data = LlamaBenchDataJSON(input_file)
    elif LlamaBenchDataCSV.valid_format(input_file):
        bench_data = LlamaBenchDataCSV(input_file)

if not bench_data:
    raise RuntimeError("No valid (or some invalid) input files found.")

if not bench_data.builds:
    raise RuntimeError(f"{input_file} does not contain any builds.")
0 commit comments