Skip to content

Commit 1422422

Browse files
authored
delete all keys not in database
1 parent c8334e0 commit 1422422

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

scripts/compare-llama-bench.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
logger = logging.getLogger("compare-llama-bench")
2323

2424
# All llama-bench SQLite3 fields
25-
DB_FIELDS = [
25+
DB_FIELDS = {
2626
"build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
2727
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
2828
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
2929
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
3030
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
3131
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
32-
]
32+
}
3333

3434
# Properties by which to differentiate results per commit:
3535
KEY_PROPERTIES = [
@@ -306,10 +306,8 @@ def __init__(self, data_file: str):
306306
for i, line in enumerate(fp):
307307
parsed = json.loads(line)
308308

309-
if "samples_ns" in parsed:
310-
del parsed["samples_ns"]
311-
if "samples_ts" in parsed:
312-
del parsed["samples_ts"]
309+
for k in parsed.keys() - DB_FIELDS:
310+
del parsed[k]
313311

314312
if (missing_keys := self._check_keys(parsed.keys())):
315313
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
@@ -328,10 +326,8 @@ def __init__(self, data_files: list[str]):
328326
parsed = json.load(fp)
329327

330328
for i, entry in enumerate(parsed):
331-
if "samples_ns" in entry:
332-
del entry["samples_ns"]
333-
if "samples_ts" in entry:
334-
del entry["samples_ts"]
329+
for k in entry.keys() - DB_FIELDS:
330+
del entry[k]
335331

336332
if (missing_keys := self._check_keys(entry.keys())):
337333
raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
@@ -348,7 +344,12 @@ def __init__(self, data_files: list[str]):
348344
for data_file in data_files:
349345
with open(data_file, "r", encoding="utf-8") as fp:
350346
for i, parsed in enumerate(csv.DictReader(fp)):
351-
if (missing_keys := self._check_keys(set(parsed.keys()))):
347+
keys = set(parsed.keys())
348+
349+
for k in keys - DB_FIELDS:
350+
del parsed[k]
351+
352+
if (missing_keys := self._check_keys(keys)):
352353
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
353354

354355
# FIXME: Convert float/int columns from str!

0 commit comments

Comments
 (0)