Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions scripts/compare-llama-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,28 +315,29 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare


class LlamaBenchDataSQLite3(LlamaBenchData):
connection: sqlite3.Connection
connection: Optional[sqlite3.Connection] = None
cursor: sqlite3.Cursor
table_name: str

def __init__(self, tool: str = "llama-bench"):
super().__init__(tool)
self.connection = sqlite3.connect(":memory:")
self.cursor = self.connection.cursor()
if self.connection is None:
self.connection = sqlite3.connect(":memory:")
self.cursor = self.connection.cursor()

# Set table name and schema based on tool
if self.tool == "llama-bench":
self.table_name = "llama_bench"
db_fields = LLAMA_BENCH_DB_FIELDS
db_types = LLAMA_BENCH_DB_TYPES
elif self.tool == "test-backend-ops":
self.table_name = "test_backend_ops"
db_fields = TEST_BACKEND_OPS_DB_FIELDS
db_types = TEST_BACKEND_OPS_DB_TYPES
else:
assert False
# Set table name and schema based on tool
if self.tool == "llama-bench":
self.table_name = "llama_bench"
db_fields = LLAMA_BENCH_DB_FIELDS
db_types = LLAMA_BENCH_DB_TYPES
elif self.tool == "test-backend-ops":
self.table_name = "test_backend_ops"
db_fields = TEST_BACKEND_OPS_DB_FIELDS
db_types = TEST_BACKEND_OPS_DB_TYPES
else:
assert False

self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")
self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});")

def _builds_init(self):
if self.connection:
Expand Down Expand Up @@ -397,9 +398,6 @@ def _get_rows_test_backend_ops(self, properties: list[str], hexsha8_baseline: st

class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3):
def __init__(self, data_file: str, tool: Any):
super().__init__(tool)

self.connection.close()
self.connection = sqlite3.connect(data_file)
self.cursor = self.connection.cursor()

Expand All @@ -411,27 +409,28 @@ def __init__(self, data_file: str, tool: Any):
if tool is None:
if "llama_bench" in table_names:
self.table_name = "llama_bench"
self.tool = "llama-bench"
tool = "llama-bench"
elif "test_backend_ops" in table_names:
self.table_name = "test_backend_ops"
self.tool = "test-backend-ops"
tool = "test-backend-ops"
else:
raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}")
elif tool == "llama-bench":
if "llama_bench" in table_names:
self.table_name = "llama_bench"
self.tool = "llama-bench"
tool = "llama-bench"
else:
raise RuntimeError(f"Table 'test' not found for tool 'llama-bench'. Available tables: {table_names}")
elif tool == "test-backend-ops":
if "test_backend_ops" in table_names:
self.table_name = "test_backend_ops"
self.tool = "test-backend-ops"
tool = "test-backend-ops"
else:
raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}")
else:
raise RuntimeError(f"Unknown tool: {tool}")

super().__init__(tool)
self._builds_init()

@staticmethod
Expand Down Expand Up @@ -653,6 +652,8 @@ def get_flops_unit_name(flops_values: list) -> str:
if not bench_data.builds:
raise RuntimeError(f"{input_file} does not contain any builds.")

tool = bench_data.tool # May have chosen a default if tool was None.


hexsha8_baseline = name_baseline = None

Expand Down