@@ -315,28 +315,29 @@ def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare
315
315
316
316
317
317
class LlamaBenchDataSQLite3 (LlamaBenchData ):
318
- connection : sqlite3 .Connection
318
+ connection : Optional [ sqlite3 .Connection ] = None
319
319
cursor : sqlite3 .Cursor
320
320
table_name : str
321
321
322
322
def __init__ (self , tool : str = "llama-bench" ):
323
323
super ().__init__ (tool )
324
- self .connection = sqlite3 .connect (":memory:" )
325
- self .cursor = self .connection .cursor ()
324
+ if self .connection is None :
325
+ self .connection = sqlite3 .connect (":memory:" )
326
+ self .cursor = self .connection .cursor ()
326
327
327
- # Set table name and schema based on tool
328
- if self .tool == "llama-bench" :
329
- self .table_name = "llama_bench"
330
- db_fields = LLAMA_BENCH_DB_FIELDS
331
- db_types = LLAMA_BENCH_DB_TYPES
332
- elif self .tool == "test-backend-ops" :
333
- self .table_name = "test_backend_ops"
334
- db_fields = TEST_BACKEND_OPS_DB_FIELDS
335
- db_types = TEST_BACKEND_OPS_DB_TYPES
336
- else :
337
- assert False
328
+ # Set table name and schema based on tool
329
+ if self .tool == "llama-bench" :
330
+ self .table_name = "llama_bench"
331
+ db_fields = LLAMA_BENCH_DB_FIELDS
332
+ db_types = LLAMA_BENCH_DB_TYPES
333
+ elif self .tool == "test-backend-ops" :
334
+ self .table_name = "test_backend_ops"
335
+ db_fields = TEST_BACKEND_OPS_DB_FIELDS
336
+ db_types = TEST_BACKEND_OPS_DB_TYPES
337
+ else :
338
+ assert False
338
339
339
- self .cursor .execute (f"CREATE TABLE { self .table_name } ({ ', ' .join (' ' .join (x ) for x in zip (db_fields , db_types ))} );" )
340
+ self .cursor .execute (f"CREATE TABLE { self .table_name } ({ ', ' .join (' ' .join (x ) for x in zip (db_fields , db_types ))} );" )
340
341
341
342
def _builds_init (self ):
342
343
if self .connection :
@@ -397,9 +398,6 @@ def _get_rows_test_backend_ops(self, properties: list[str], hexsha8_baseline: st
397
398
398
399
class LlamaBenchDataSQLite3File (LlamaBenchDataSQLite3 ):
399
400
def __init__ (self , data_file : str , tool : Any ):
400
- super ().__init__ (tool )
401
-
402
- self .connection .close ()
403
401
self .connection = sqlite3 .connect (data_file )
404
402
self .cursor = self .connection .cursor ()
405
403
@@ -411,27 +409,28 @@ def __init__(self, data_file: str, tool: Any):
411
409
if tool is None :
412
410
if "llama_bench" in table_names :
413
411
self .table_name = "llama_bench"
414
- self . tool = "llama-bench"
412
+ tool = "llama-bench"
415
413
elif "test_backend_ops" in table_names :
416
414
self .table_name = "test_backend_ops"
417
- self . tool = "test-backend-ops"
415
+ tool = "test-backend-ops"
418
416
else :
419
417
raise RuntimeError (f"No suitable table found in database. Available tables: { table_names } " )
420
418
elif tool == "llama-bench" :
421
419
if "llama_bench" in table_names :
422
420
self .table_name = "llama_bench"
423
- self . tool = "llama-bench"
421
+ tool = "llama-bench"
424
422
else :
425
423
raise RuntimeError (f"Table 'test' not found for tool 'llama-bench'. Available tables: { table_names } " )
426
424
elif tool == "test-backend-ops" :
427
425
if "test_backend_ops" in table_names :
428
426
self .table_name = "test_backend_ops"
429
- self . tool = "test-backend-ops"
427
+ tool = "test-backend-ops"
430
428
else :
431
429
raise RuntimeError (f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: { table_names } " )
432
430
else :
433
431
raise RuntimeError (f"Unknown tool: { tool } " )
434
432
433
+ super ().__init__ (tool )
435
434
self ._builds_init ()
436
435
437
436
@staticmethod
@@ -653,6 +652,8 @@ def get_flops_unit_name(flops_values: list) -> str:
653
652
if not bench_data .builds :
654
653
raise RuntimeError (f"{ input_file } does not contain any builds." )
655
654
655
+ tool = bench_data .tool # May have chosen a default if tool was None.
656
+
656
657
657
658
hexsha8_baseline = name_baseline = None
658
659
0 commit comments