Skip to content

Commit 295dcca

Browse files
authored
support multiple files and file type checking
1 parent ff1dee1 commit 295dcca

File tree

1 file changed

+91
-34
lines changed

1 file changed

+91
-34
lines changed

scripts/compare-llama-bench.py

Lines changed: 91 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables.
6868
MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"}
6969

70-
DESCRIPTION = """Creates tables from llama-bench data written to a JSONL file or SQLite database. Example usage (Linux):
70+
DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux):
7171
7272
$ git checkout master
7373
$ make clean && make llama-bench
@@ -95,12 +95,13 @@
9595
)
9696
parser.add_argument("-c", "--compare", help=help_c)
9797
help_i = (
98-
"Input JSONL/SQLite file for comparing commits. "
98+
"Input JSONL/SQLite file or JSON/CSV files for comparing commits. "
99+
"Specify this argument multiple times for multiple files. "
99100
"Defaults to 'llama-bench.sqlite' in the current working directory. "
100101
"If no such file is found and there is exactly one .sqlite file in the current directory, "
101102
"that file is instead used as input."
102103
)
103-
parser.add_argument("-i", "--input", help=help_i)
104+
parser.add_argument("-i", "--input", action="append", help=help_i)
104105
help_o = (
105106
"Output format for the table. "
106107
"Defaults to 'pipe' (GitHub compatible). "
@@ -135,14 +136,14 @@
135136
sys.exit(1)
136137

137138
input_file = known_args.input
138-
if input_file is None and os.path.exists("./llama-bench.sqlite"):
139-
input_file = "llama-bench.sqlite"
140-
if input_file is None:
139+
if not input_file and os.path.exists("./llama-bench.sqlite"):
140+
input_file = ["llama-bench.sqlite"]
141+
if not input_file:
141142
sqlite_files = glob("*.sqlite")
142143
if len(sqlite_files) == 1:
143-
input_file = sqlite_files[0]
144+
input_file = sqlite_files
144145

145-
if input_file is None:
146+
if not input_file:
146147
logger.error("Cannot find a suitable input file, please provide one.\n")
147148
parser.print_help()
148149
sys.exit(1)
@@ -285,48 +286,59 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare
285286

286287

287288
class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
288-
connected_file = False
289-
290289
def __init__(self, data_file: str):
291290
super().__init__()
292291

292+
self.connection = sqlite3.connect(data_file)
293+
self.cursor = self.connection.cursor()
294+
self._builds_init()
295+
296+
@staticmethod
297+
def valid_format(data_file: str) -> bool:
293298
connection = sqlite3.connect(data_file)
294299
cursor = connection.cursor()
295300

296-
# Test if data_file is a valid SQLite database
297301
try:
298302
if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
299-
raise RuntimeError("The provided input file does not exist or is empty.")
300-
except sqlite3.DatabaseError:
301-
connection.close()
302-
connection = None
303+
raise sqlite3.DatabaseError("The provided input file does not exist or is empty.")
304+
except sqlite3.DatabaseError as e:
305+
logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e)
306+
cursor = None
303307

304-
if (connection):
305-
self.connected_file = True
306-
self.connection.close()
307-
self.connection = connection
308-
self.cursor = cursor
309-
self._builds_init()
308+
connection.close()
309+
return True if cursor else False
310310

311311

312-
class LlamaBenchDataSQLite3_or_JSONL(LlamaBenchDataSQLite3File):
312+
class LlamaBenchDataJSONL(LlamaBenchDataSQLite3):
313313
def __init__(self, data_file: str):
314-
super().__init__(data_file)
314+
super().__init__()
315315

316-
if not self.connected_file:
317-
with open(data_file, "r", encoding="utf-8") as fp:
318-
for i, line in enumerate(fp):
319-
parsed = json.loads(line)
316+
with open(data_file, "r", encoding="utf-8") as fp:
317+
for i, line in enumerate(fp):
318+
parsed = json.loads(line)
320319

321-
for k in parsed.keys() - set(DB_FIELDS):
322-
del parsed[k]
320+
for k in parsed.keys() - set(DB_FIELDS):
321+
del parsed[k]
323322

324-
if (missing_keys := self._check_keys(parsed.keys())):
325-
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
323+
if (missing_keys := self._check_keys(parsed.keys())):
324+
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
326325

327-
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
326+
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
328327

329-
self._builds_init()
328+
self._builds_init()
329+
330+
@staticmethod
331+
def valid_format(data_file: str) -> bool:
332+
try:
333+
with open(data_file, "r", encoding="utf-8") as fp:
334+
for line in fp:
335+
json.loads(line)
336+
break
337+
except Exception as e:
338+
logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e)
339+
return False
340+
341+
return True
330342

331343

332344
class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
@@ -348,6 +360,22 @@ def __init__(self, data_files: list[str]):
348360

349361
self._builds_init()
350362

363+
@staticmethod
364+
def valid_format(data_files: list[str]) -> bool:
365+
if not data_files:
366+
return False
367+
368+
for data_file in data_files:
369+
try:
370+
with open(data_file, "r", encoding="utf-8") as fp:
371+
json.load(fp)
372+
continue
373+
except Exception as e:
374+
logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e)
375+
return False
376+
377+
return True
378+
351379

352380
class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
353381
def __init__(self, data_files: list[str]):
@@ -368,8 +396,37 @@ def __init__(self, data_files: list[str]):
368396

369397
self._builds_init()
370398

399+
@staticmethod
400+
def valid_format(data_files: list[str]) -> bool:
401+
if not data_files:
402+
return False
403+
404+
for data_file in data_files:
405+
try:
406+
with open(data_file, "r", encoding="utf-8") as fp:
407+
csv.DictReader(fp)
408+
continue
409+
except Exception as e:
410+
logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e)
411+
return False
412+
413+
return True
414+
415+
416+
bench_data = None
417+
if len(input_file) == 1:
418+
if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
419+
bench_data = LlamaBenchDataSQLite3File(input_file[0])
420+
elif LlamaBenchDataJSONL.valid_format(input_file[0]):
421+
bench_data = LlamaBenchDataJSONL(input_file[0])
422+
else:
423+
if LlamaBenchDataJSON.valid_format(input_file):
424+
bench_data = LlamaBenchDataJSON(input_file)
425+
elif LlamaBenchDataCSV.valid_format(input_file):
426+
bench_data = LlamaBenchDataCSV(input_file)
371427

372-
bench_data = LlamaBenchDataSQLite3_or_JSONL(input_file)
428+
if not bench_data:
429+
raise RuntimeError("No valid (or some invalid) input files found.")
373430

374431
if not bench_data.builds:
375432
raise RuntimeError(f"{input_file} does not contain any builds.")

0 commit comments

Comments
 (0)