Skip to content

Commit c8334e0

Browse files
authored
refactor into in-memory database
1 parent f5dd8df commit c8334e0

File tree

1 file changed

+102
-108
lines changed

1 file changed

+102
-108
lines changed

scripts/compare-llama-bench.py

Lines changed: 102 additions & 108 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,6 @@
99
import sqlite3
1010
import json
1111
import csv
12-
from functools import reduce
13-
from itertools import groupby
14-
from statistics import fmean
1512
from typing import Optional, Union
1613
from collections.abc import Iterator, Sequence
1714

@@ -24,6 +21,16 @@
2421

2522
logger = logging.getLogger("compare-llama-bench")
2623

24+
# All llama-bench SQLite3 fields
25+
DB_FIELDS = [
26+
"build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
27+
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
28+
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
29+
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
30+
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
31+
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
32+
]
33+
2734
# Properties by which to differentiate results per commit:
2835
KEY_PROPERTIES = [
2936
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
@@ -136,6 +143,7 @@ class LlamaBenchData:
136143
build_len_max: int
137144
build_len: int = 8
138145
builds: list[str] = []
146+
check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
139147

140148
def __init__(self):
141149
try:
@@ -146,6 +154,12 @@ def __init__(self):
146154
def _builds_init(self):
147155
self.build_len = self.build_len_min
148156

157+
def _check_keys(self, keys: set) -> Optional[set]:
158+
"""Private helper method that checks against required data keys and returns missing ones."""
159+
if not keys >= self.check_keys:
160+
return self.check_keys - keys
161+
return None
162+
149163
def find_parent_in_data(self, commit: git.Commit) -> Optional[str]:
150164
"""Helper method to find the most recent parent measured in number of commits for which there is data."""
151165
heap: list[tuple[int, git.Commit]] = [(0, commit)]
@@ -217,79 +231,117 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare
217231
return []
218232

219233

220-
class LlamaBenchDataGeneric(LlamaBenchData):
221-
data: list
222-
check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"])
234+
class LlamaBenchDataSQLite3(LlamaBenchData):
235+
connection: sqlite3.Connection
236+
cursor: sqlite3.Cursor
223237

224238
def __init__(self):
225239
super().__init__()
226-
self.data = []
227-
228-
def _check_keys(self, keys: set) -> Optional[set]:
229-
if not keys >= self.check_keys:
230-
return self.check_keys - keys
231-
return None
240+
self.connection = sqlite3.connect(":memory:")
241+
self.cursor = self.connection.cursor()
242+
self.cursor.execute(f"CREATE TABLE test({', '.join(DB_FIELDS)});")
232243

233244
def _builds_init(self):
234-
self.build_len_min, self.build_len_max = reduce(lambda x, y: (min(x[0], y), max(x[1], y)), (len(d["build_commit"]) for d in self.data), (1000, 0))
235-
self.builds = list(set(d["build_commit"] for d in self.data))
245+
if self.connection:
246+
self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
247+
self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
248+
249+
if self.build_len_min != self.build_len_max:
250+
logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
251+
"Try purging the the database of old commits.")
252+
self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
253+
254+
builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
255+
self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
236256
super()._builds_init()
237257

238258
def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
239-
return sorted(((d["build_commit"], d["test_time"]) for d in self.data), key=lambda x: x[1], reverse=reverse)
259+
data = self.cursor.execute(
260+
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
261+
return reversed(data) if reverse else data
240262

241263
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
242-
select_data = []
243-
join_equal = lambda x: tuple(x[p] for p in KEY_PROPERTIES) # noqa: E731
244-
group_order = lambda x: tuple(x[p] for p in properties + ["n_gen", "n_prompt", "n_depth"]) # noqa: E731
245-
for _, g in groupby(sorted(self.data, key=group_order), key=group_order):
246-
g = list(g)
247-
join_on = {}
248-
for row in filter(lambda x: x["build_commit"] == hexsha8_baseline, g):
249-
if (join_row := join_equal(row)) not in join_on:
250-
row_copy = row.copy()
251-
row_copy["avg_ts"] = []
252-
join_on[join_row] = row_copy
253-
join_on[join_row]["avg_ts"].append(row["avg_ts"])
254-
joined = {}
255-
for row in filter(lambda x: x["build_commit"] == hexsha8_compare, g):
256-
if (join_row := join_equal(row)) in join_on:
257-
joined.setdefault(join_row, join_on[join_row]).setdefault("tc.avg_ts", []).append(row["avg_ts"])
258-
for row in joined.values():
259-
select_data.append(tuple(row[p] for p in properties + ["n_prompt", "n_gen", "n_depth"]) + (fmean(row["avg_ts"]), fmean(row["tc.avg_ts"])))
260-
return select_data
261-
262-
263-
class LlamaBenchDataJSONL(LlamaBenchDataGeneric):
264+
select_string = ", ".join(
265+
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
266+
equal_string = " AND ".join(
267+
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
268+
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
269+
)
270+
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
271+
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
272+
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
273+
return self.cursor.execute(query).fetchall()
274+
275+
276+
class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
277+
connected_file = False
278+
264279
def __init__(self, data_file: str):
265280
super().__init__()
266281

267-
with open(data_file, "r", encoding="utf-8") as fp:
268-
for i, line in enumerate(fp):
269-
parsed = json.loads(line)
270-
if (missing_keys := self._check_keys(parsed.keys())):
271-
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
272-
self.data.append(parsed)
282+
connection = sqlite3.connect(data_file)
283+
cursor = connection.cursor()
273284

274-
self._builds_init()
285+
# Test if data_file is a valid SQLite database
286+
try:
287+
if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
288+
raise RuntimeError("The provided input file does not exist or is empty.")
289+
except sqlite3.DatabaseError:
290+
connection.close()
291+
connection = None
292+
293+
if (connection):
294+
self.connected_file = True
295+
self.connection = connection
296+
self.cursor = cursor
297+
self._builds_init()
298+
299+
300+
class LlamaBenchDataSQLite3_or_JSONL(LlamaBenchDataSQLite3File):
301+
def __init__(self, data_file: str):
302+
super().__init__(data_file)
303+
304+
if not self.connected_file:
305+
with open(data_file, "r", encoding="utf-8") as fp:
306+
for i, line in enumerate(fp):
307+
parsed = json.loads(line)
308+
309+
if "samples_ns" in parsed:
310+
del parsed["samples_ns"]
311+
if "samples_ts" in parsed:
312+
del parsed["samples_ts"]
275313

314+
if (missing_keys := self._check_keys(parsed.keys())):
315+
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
316+
317+
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
318+
319+
self._builds_init()
276320

277-
class LlamaBenchDataJSON(LlamaBenchDataGeneric):
321+
322+
class LlamaBenchDataJSON(LlamaBenchDataSQLite3):
278323
def __init__(self, data_files: list[str]):
279324
super().__init__()
280325

281326
for data_file in data_files:
282327
with open(data_file, "r", encoding="utf-8") as fp:
283328
parsed = json.load(fp)
329+
284330
for i, entry in enumerate(parsed):
331+
if "samples_ns" in entry:
332+
del entry["samples_ns"]
333+
if "samples_ts" in entry:
334+
del entry["samples_ts"]
335+
285336
if (missing_keys := self._check_keys(entry.keys())):
286337
raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}")
287-
self.data += parsed
338+
339+
self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values()))
288340

289341
self._builds_init()
290342

291343

292-
class LlamaBenchDataCSV(LlamaBenchDataGeneric):
344+
class LlamaBenchDataCSV(LlamaBenchDataSQLite3):
293345
def __init__(self, data_files: list[str]):
294346
super().__init__()
295347

@@ -298,72 +350,14 @@ def __init__(self, data_files: list[str]):
298350
for i, parsed in enumerate(csv.DictReader(fp)):
299351
if (missing_keys := self._check_keys(set(parsed.keys()))):
300352
raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}")
353+
301354
# FIXME: Convert float/int columns from str!
302-
self.data.append(parsed)
355+
self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values()))
303356

304357
self._builds_init()
305358

306359

307-
class LlamaBenchDataSQLite3(LlamaBenchData):
308-
connection: Optional[sqlite3.Connection] = None
309-
cursor: sqlite3.Cursor
310-
311-
def __init__(self, data_file: str):
312-
super().__init__()
313-
314-
connection = sqlite3.connect(data_file)
315-
cursor = connection.cursor()
316-
317-
# Test if data_file is a valid SQLite database
318-
try:
319-
if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0:
320-
raise RuntimeError("The provided input file does not exist or is empty.")
321-
except sqlite3.DatabaseError:
322-
connection.close()
323-
connection = None
324-
325-
if (connection):
326-
self.connection = connection
327-
self.cursor = cursor
328-
329-
self.build_len_min = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0]
330-
self.build_len_max = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0]
331-
332-
if self.build_len_min != self.build_len_max:
333-
logger.warning(f"{data_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. "
334-
"Try purging the the database of old commits.")
335-
cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});")
336-
337-
self._builds_init()
338-
339-
def _builds_init(self):
340-
if self.connection:
341-
builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
342-
self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str]
343-
super()._builds_init()
344-
345-
def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]:
346-
data = self.cursor.execute(
347-
"SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall()
348-
return reversed(data) if reverse else data
349-
350-
def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]:
351-
select_string = ", ".join(
352-
[f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"])
353-
equal_string = " AND ".join(
354-
[f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [
355-
f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"]
356-
)
357-
group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"])
358-
query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} "
359-
f"GROUP BY {group_order_string} ORDER BY {group_order_string};")
360-
return self.cursor.execute(query).fetchall()
361-
362-
363-
bench_data = LlamaBenchDataSQLite3(input_file)
364-
if not bench_data.connection:
365-
# Not a SQLite database, try JSONL instead
366-
bench_data = LlamaBenchDataJSONL(input_file)
360+
bench_data = LlamaBenchDataSQLite3_or_JSONL(input_file)
367361

368362
if not bench_data.builds:
369363
raise RuntimeError(f"{input_file} does not contain any builds.")

0 commit comments

Comments
 (0)