Skip to content

Commit 6ad2a01

Browse files
authored
address change requests
1 parent f5097af commit 6ad2a01

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

scripts/compare-llama-bench.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
logger = logging.getLogger("compare-llama-bench")
2323

24-
# All llama-bench SQLite3 fields
24+
# All llama-bench SQL fields
2525
DB_FIELDS = [
2626
"build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
2727
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
@@ -41,6 +41,7 @@
4141
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
4242
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
4343
]
44+
assert len(DB_FIELDS) == len(DB_TYPES)
4445

4546
# Properties by which to differentiate results per commit:
4647
KEY_PROPERTIES = [
@@ -95,8 +96,8 @@
9596
)
9697
parser.add_argument("-c", "--compare", help=help_c)
9798
help_i = (
98-
"Input JSONL/SQLite file or JSON/CSV files for comparing commits. "
99-
"Specify this argument multiple times for multiple files. "
99+
"JSON/JSONL/SQLite/CSV files for comparing commits. "
100+
"Specify multiple times to use multiple input files (JSON/CSV only). "
100101
"Defaults to 'llama-bench.sqlite' in the current working directory. "
101102
"If no such file is found and there is exactly one .sqlite file in the current directory, "
102103
"that file is instead used as input."
@@ -289,6 +290,7 @@ class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
289290
def __init__(self, data_file: str):
290291
super().__init__()
291292

293+
self.connection.close()
292294
self.connection = sqlite3.connect(data_file)
293295
self.cursor = self.connection.cursor()
294296
self._builds_init()
@@ -416,8 +418,12 @@ def valid_format(data_files: list[str]) -> bool:
416418
if len(input_file) == 1:
417419
if LlamaBenchDataSQLite3File.valid_format(input_file[0]):
418420
bench_data = LlamaBenchDataSQLite3File(input_file[0])
421+
elif LlamaBenchDataJSON.valid_format(input_file):
422+
bench_data = LlamaBenchDataJSON(input_file)
419423
elif LlamaBenchDataJSONL.valid_format(input_file[0]):
420424
bench_data = LlamaBenchDataJSONL(input_file[0])
425+
elif LlamaBenchDataCSV.valid_format(input_file):
426+
bench_data = LlamaBenchDataCSV(input_file)
421427
else:
422428
if LlamaBenchDataJSON.valid_format(input_file):
423429
bench_data = LlamaBenchDataJSON(input_file)

0 commit comments

Comments
 (0)