|
21 | 21 |
|
22 | 22 | logger = logging.getLogger("compare-llama-bench") |
23 | 23 |
|
24 | | -# All llama-bench SQLite3 fields |
| 24 | +# All llama-bench SQL fields |
25 | 25 | DB_FIELDS = [ |
26 | 26 | "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename", |
27 | 27 | "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", |
|
41 | 41 | "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", |
42 | 42 | "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", |
43 | 43 | ] |
| 44 | +assert len(DB_FIELDS) == len(DB_TYPES) |
44 | 45 |
|
45 | 46 | # Properties by which to differentiate results per commit: |
46 | 47 | KEY_PROPERTIES = [ |
|
95 | 96 | ) |
96 | 97 | parser.add_argument("-c", "--compare", help=help_c) |
97 | 98 | help_i = ( |
98 | | - "Input JSONL/SQLite file or JSON/CSV files for comparing commits. " |
99 | | - "Specify this argument multiple times for multiple files. " |
| 99 | + "JSON/JSONL/SQLite/CSV files for comparing commits. " |
| 100 | + "Specify multiple times to use multiple input files (JSON/CSV only). " |
100 | 101 | "Defaults to 'llama-bench.sqlite' in the current working directory. " |
101 | 102 | "If no such file is found and there is exactly one .sqlite file in the current directory, " |
102 | 103 | "that file is instead used as input." |
@@ -289,6 +290,7 @@ class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): |
289 | 290 | def __init__(self, data_file: str): |
290 | 291 | super().__init__() |
291 | 292 |
|
| 293 | + self.connection.close() |
292 | 294 | self.connection = sqlite3.connect(data_file) |
293 | 295 | self.cursor = self.connection.cursor() |
294 | 296 | self._builds_init() |
@@ -416,8 +418,12 @@ def valid_format(data_files: list[str]) -> bool: |
416 | 418 | if len(input_file) == 1: |
417 | 419 | if LlamaBenchDataSQLite3File.valid_format(input_file[0]): |
418 | 420 | bench_data = LlamaBenchDataSQLite3File(input_file[0]) |
| 421 | + elif LlamaBenchDataJSON.valid_format(input_file): |
| 422 | + bench_data = LlamaBenchDataJSON(input_file) |
419 | 423 | elif LlamaBenchDataJSONL.valid_format(input_file[0]): |
420 | 424 | bench_data = LlamaBenchDataJSONL(input_file[0]) |
| 425 | + elif LlamaBenchDataCSV.valid_format(input_file): |
| 426 | + bench_data = LlamaBenchDataCSV(input_file) |
421 | 427 | else: |
422 | 428 | if LlamaBenchDataJSON.valid_format(input_file): |
423 | 429 | bench_data = LlamaBenchDataJSON(input_file) |
|
0 commit comments