@@ -315,28 +315,29 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare
315315
316316
317317class LlamaBenchDataSQLite3 (LlamaBenchData ):
318- connection : sqlite3 .Connection
318+ connection : Optional [ sqlite3 .Connection ] = None
319319 cursor : sqlite3 .Cursor
320320 table_name : str
321321
322322 def __init__ (self , tool : str = "llama-bench" ):
323323 super ().__init__ (tool )
324- self .connection = sqlite3 .connect (":memory:" )
325- self .cursor = self .connection .cursor ()
324+ if self .connection is None :
325+ self .connection = sqlite3 .connect (":memory:" )
326+ self .cursor = self .connection .cursor ()
326327
327- # Set table name and schema based on tool
328- if self .tool == "llama-bench" :
329- self .table_name = "llama_bench"
330- db_fields = LLAMA_BENCH_DB_FIELDS
331- db_types = LLAMA_BENCH_DB_TYPES
332- elif self .tool == "test-backend-ops" :
333- self .table_name = "test_backend_ops"
334- db_fields = TEST_BACKEND_OPS_DB_FIELDS
335- db_types = TEST_BACKEND_OPS_DB_TYPES
336- else :
337- assert False
328+ # Set table name and schema based on tool
329+ if self .tool == "llama-bench" :
330+ self .table_name = "llama_bench"
331+ db_fields = LLAMA_BENCH_DB_FIELDS
332+ db_types = LLAMA_BENCH_DB_TYPES
333+ elif self .tool == "test-backend-ops" :
334+ self .table_name = "test_backend_ops"
335+ db_fields = TEST_BACKEND_OPS_DB_FIELDS
336+ db_types = TEST_BACKEND_OPS_DB_TYPES
337+ else :
338+ assert False
338339
339- self .cursor .execute (f"CREATE TABLE { self .table_name } ({ ', ' .join (' ' .join (x ) for x in zip (db_fields , db_types ))} );" )
340+ self .cursor .execute (f"CREATE TABLE { self .table_name } ({ ', ' .join (' ' .join (x ) for x in zip (db_fields , db_types ))} );" )
340341
341342 def _builds_init (self ):
342343 if self .connection :
@@ -397,9 +398,6 @@ def _get_rows_test_backend_ops(self, properties: list[str], hexsha8_baseline: st
397398
398399class LlamaBenchDataSQLite3File (LlamaBenchDataSQLite3 ):
399400 def __init__ (self , data_file : str , tool : Any ):
400- super ().__init__ (tool )
401-
402- self .connection .close ()
403401 self .connection = sqlite3 .connect (data_file )
404402 self .cursor = self .connection .cursor ()
405403
@@ -411,27 +409,28 @@ def __init__(self, data_file: str, tool: Any):
411409 if tool is None :
412410 if "llama_bench" in table_names :
413411 self .table_name = "llama_bench"
414- self . tool = "llama-bench"
412+ tool = "llama-bench"
415413 elif "test_backend_ops" in table_names :
416414 self .table_name = "test_backend_ops"
417- self . tool = "test-backend-ops"
415+ tool = "test-backend-ops"
418416 else :
419417 raise RuntimeError (f"No suitable table found in database. Available tables: { table_names } " )
420418 elif tool == "llama-bench" :
421419 if "llama_bench" in table_names :
422420 self .table_name = "llama_bench"
423- self . tool = "llama-bench"
421+ tool = "llama-bench"
424422 else :
425423 raise RuntimeError (f"Table 'test' not found for tool 'llama-bench'. Available tables: { table_names } " )
426424 elif tool == "test-backend-ops" :
427425 if "test_backend_ops" in table_names :
428426 self .table_name = "test_backend_ops"
429- self . tool = "test-backend-ops"
427+ tool = "test-backend-ops"
430428 else :
431429 raise RuntimeError (f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: { table_names } " )
432430 else :
433431 raise RuntimeError (f"Unknown tool: { tool } " )
434432
433+ super ().__init__ (tool )
435434 self ._builds_init ()
436435
437436 @staticmethod
@@ -653,6 +652,8 @@ def get_flops_unit_name(flops_values: list) -> str:
653652if not bench_data .builds :
654653 raise RuntimeError (f"{ input_file } does not contain any builds." )
655654
655+ tool = bench_data .tool # May have chosen a default if tool was None.
656+
656657
657658hexsha8_baseline = name_baseline = None
658659
0 commit comments