99import sqlite3
1010import json
1111import csv
12- from functools import reduce
13- from itertools import groupby
14- from statistics import fmean
1512from typing import Optional , Union
1613from collections .abc import Iterator , Sequence
1714
2421
2522logger = logging .getLogger ("compare-llama-bench" )
2623
# All llama-bench SQLite3 fields, in table column order.
DB_FIELDS = [
    "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
    "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
    "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
    "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
    "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
    "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
]
33+
2734# Properties by which to differentiate results per commit:
2835KEY_PROPERTIES = [
2936 "cpu_info" , "gpu_info" , "backends" , "n_gpu_layers" , "tensor_buft_overrides" , "model_filename" , "model_type" ,
@@ -136,6 +143,7 @@ class LlamaBenchData:
136143 build_len_max : int
137144 build_len : int = 8
138145 builds : list [str ] = []
146+ check_keys = set (KEY_PROPERTIES + ["build_commit" , "test_time" , "avg_ts" ])
139147
140148 def __init__ (self ):
141149 try :
@@ -146,6 +154,12 @@ def __init__(self):
146154 def _builds_init (self ):
147155 self .build_len = self .build_len_min
148156
157+ def _check_keys (self , keys : set ) -> Optional [set ]:
158+ """Private helper method that checks against required data keys and returns missing ones."""
159+ if not keys >= self .check_keys :
160+ return self .check_keys - keys
161+ return None
162+
149163 def find_parent_in_data (self , commit : git .Commit ) -> Optional [str ]:
150164 """Helper method to find the most recent parent measured in number of commits for which there is data."""
151165 heap : list [tuple [int , git .Commit ]] = [(0 , commit )]
@@ -217,79 +231,117 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare
217231 return []
218232
219233
class LlamaBenchDataSQLite3(LlamaBenchData):
    """llama-bench result storage backed by an in-memory SQLite3 database.

    Subclasses populate the ``test`` table (from a file, JSON, or CSV) and
    then call :meth:`_builds_init`.
    """

    connection: sqlite3.Connection
    cursor: sqlite3.Cursor

    def __init__(self):
        super().__init__()
        self.connection = sqlite3.connect(":memory:")
        self.cursor = self.connection.cursor()
        self.cursor.execute(f"CREATE TABLE test({', '.join(DB_FIELDS)});")

    def _builds_init(self):
        if self.connection:
            self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
            self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]

            if self.build_len_min != self.build_len_max:
                # Mixed hash lengths: truncate everything to the shortest so prefixes compare equal.
                logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
                               "Try purging the database of old commits.")
                self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")

            builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
            self.builds = [b[0] for b in builds]  # list[tuple[str]] -> list[str]
        super()._builds_init()

    def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
        """Return (build_commit, test_time) pairs ordered by test time."""
        data = self.cursor.execute(
            "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
        return reversed(data) if reverse else data

    def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
        """Self-join the test table to pair baseline rows with compare rows.

        Rows are matched on all KEY_PROPERTIES; avg_ts is averaged per group.
        The commit hashes are bound as query parameters instead of being
        interpolated into the SQL string.
        """
        select_string = ", ".join(
            [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
        equal_string = " AND ".join(
            [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
                "tb.build_commit = ?", "tc.build_commit = ?"]
        )
        group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
        query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
                 f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
        return self.cursor.execute(query, (hexsha8_baseline, hexsha8_compare)).fetchall()
274+
275+
class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
    """llama-bench results read from an on-disk SQLite3 database file."""

    # True once a valid database file has replaced the in-memory scratch DB.
    connected_file = False

    def __init__(self, data_file: str):
        super().__init__()

        connection = sqlite3.connect(data_file)
        cursor = connection.cursor()

        # Test if data_file is a valid SQLite database:
        # a non-SQLite file raises DatabaseError; a freshly created (empty) one
        # reports schema_version 0.
        try:
            if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
                raise RuntimeError("The provided input file does not exist or is empty.")
        except sqlite3.DatabaseError:
            connection.close()
            connection = None

        if connection:
            self.connected_file = True
            # Close the :memory: database created by the base class before
            # replacing it, so it is not leaked.
            self.connection.close()
            self.connection = connection
            self.cursor = cursor
            self._builds_init()
298+
299+
class LlamaBenchDataSQLite3_or_JSONL(LlamaBenchDataSQLite3File):
    """llama-bench results from a file that is either SQLite3 or JSON Lines."""

    def __init__(self, data_file: str):
        super().__init__(data_file)

        if self.connected_file:
            # The parent already attached a valid SQLite database; nothing to do.
            return

        # Fall back to treating the file as JSON Lines, one record per line.
        with open(data_file, "r", encoding="utf-8") as fp:
            for i, line in enumerate(fp):
                parsed = json.loads(line)

                # Per-sample arrays are not columns of the aggregate table.
                parsed.pop("samples_ns", None)
                parsed.pop("samples_ts", None)

                if (missing_keys := self._check_keys(parsed.keys())):
                    raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")

                self.cursor.execute(
                    f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});",
                    tuple(parsed.values()))

        self._builds_init()
276320
class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
    """llama-bench results from one or more JSON files, each an array of records."""

    def __init__(self, data_files: list[str]):
        super().__init__()

        for data_file in data_files:
            with open(data_file, "r", encoding="utf-8") as fp:
                parsed = json.load(fp)

                for i, entry in enumerate(parsed):
                    # Per-sample arrays are not columns of the aggregate table.
                    entry.pop("samples_ns", None)
                    entry.pop("samples_ts", None)

                    if (missing_keys := self._check_keys(entry.keys())):
                        raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")

                    self.cursor.execute(
                        f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});",
                        tuple(entry.values()))

        self._builds_init()
291343
292- class LlamaBenchDataCSV (LlamaBenchDataGeneric ):
344+ class LlamaBenchDataCSV (LlamaBenchDataSQLite3 ):
293345 def __init__ (self , data_files : list [str ]):
294346 super ().__init__ ()
295347
@@ -298,72 +350,14 @@ def __init__(self, data_files: list[str]):
298350 for i , parsed in enumerate (csv .DictReader (fp )):
299351 if (missing_keys := self ._check_keys (set (parsed .keys ()))):
300352 raise RuntimeError (f"Missing required data key(s) at line { i + 1 } : { ', ' .join (missing_keys )} " )
353+
301354 # FIXME: Convert float/int columns from str!
302- self .data . append ( parsed )
355+ self .cursor . execute ( f"INSERT INTO test( { ', ' . join ( parsed . keys ()) } ) VALUES( { ', ' . join ( '?' * len ( parsed )) } );" , tuple ( parsed . values ()) )
303356
304357 self ._builds_init ()
305358
306359
# A single loader handles both supported input formats: try SQLite3 first,
# then fall back to JSON Lines.
bench_data = LlamaBenchDataSQLite3_or_JSONL(input_file)

if not bench_data.builds:
    raise RuntimeError(f"{input_file} does not contain any builds.")
0 commit comments