2222logger = logging .getLogger ("compare-llama-bench" )
2323
2424# All llama-bench SQLite3 fields
25- DB_FIELDS = [
25+ DB_FIELDS = {
2626 "build_commit" , "build_number" , "cpu_info" , "gpu_info" , "backends" , "model_filename" ,
2727 "model_type" , "model_size" , "model_n_params" , "n_batch" , "n_ubatch" , "n_threads" ,
2828 "cpu_mask" , "cpu_strict" , "poll" , "type_k" , "type_v" , "n_gpu_layers" ,
2929 "split_mode" , "main_gpu" , "no_kv_offload" , "flash_attn" , "tensor_split" , "tensor_buft_overrides" ,
3030 "use_mmap" , "embeddings" , "no_op_offload" , "n_prompt" , "n_gen" , "n_depth" ,
3131 "test_time" , "avg_ns" , "stddev_ns" , "avg_ts" , "stddev_ts" ,
32- ]
32+ }
3333
3434# Properties by which to differentiate results per commit:
3535KEY_PROPERTIES = [
@@ -306,10 +306,8 @@ def __init__(self, data_file: str):
306306 for i , line in enumerate (fp ):
307307 parsed = json .loads (line )
308308
309- if "samples_ns" in parsed :
310- del parsed ["samples_ns" ]
311- if "samples_ts" in parsed :
312- del parsed ["samples_ts" ]
309+ for k in parsed .keys () - DB_FIELDS :
310+ del parsed [k ]
313311
314312 if (missing_keys := self ._check_keys (parsed .keys ())):
315313 raise RuntimeError (f"Missing required data key(s) at line { i + 1 } : { ', ' .join (missing_keys )} " )
@@ -328,10 +326,8 @@ def __init__(self, data_files: list[str]):
328326 parsed = json .load (fp )
329327
330328 for i , entry in enumerate (parsed ):
331- if "samples_ns" in entry :
332- del entry ["samples_ns" ]
333- if "samples_ts" in entry :
334- del entry ["samples_ts" ]
329+ for k in entry .keys () - DB_FIELDS :
330+ del entry [k ]
335331
336332 if (missing_keys := self ._check_keys (entry .keys ())):
337333 raise RuntimeError (f"Missing required data key(s) at entry { i + 1 } : { ', ' .join (missing_keys )} " )
@@ -348,7 +344,12 @@ def __init__(self, data_files: list[str]):
348344 for data_file in data_files :
349345 with open (data_file , "r" , encoding = "utf-8" ) as fp :
350346 for i , parsed in enumerate (csv .DictReader (fp )):
351- if (missing_keys := self ._check_keys (set (parsed .keys ()))):
347+ keys = set (parsed .keys ())
348+
349+ for k in keys - DB_FIELDS :
350+ del parsed [k ]
351+
352+ if (missing_keys := self ._check_keys (keys )):
352353 raise RuntimeError (f"Missing required data key(s) at line { i + 1 } : { ', ' .join (missing_keys )} " )
353354
354355 # FIXME: Convert float/int columns from str!
0 commit comments