@@ -327,10 +327,12 @@ def __init__(self, tool: str = "llama-bench"):
327327 self .table_name = "test"
328328 db_fields = LLAMA_BENCH_DB_FIELDS
329329 db_types = LLAMA_BENCH_DB_TYPES
330- else : # test-backend-ops
330+ elif self . tool == " test-backend-ops" :
331331 self .table_name = "test_backend_ops"
332332 db_fields = TEST_BACKEND_OPS_DB_FIELDS
333333 db_types = TEST_BACKEND_OPS_DB_TYPES
334+ else :
335+ assert False
334336
335337 self .cursor .execute (f"CREATE TABLE { self .table_name } ({ ', ' .join (' ' .join (x ) for x in zip (db_fields , db_types ))} );" )
336338
@@ -356,8 +358,10 @@ def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequ
356358 def get_rows (self , properties : list [str ], hexsha8_baseline : str , hexsha8_compare : str ) -> Sequence [tuple ]:
357359 if self .tool == "llama-bench" :
358360 return self ._get_rows_llama_bench (properties , hexsha8_baseline , hexsha8_compare )
359- else : # test-backend-ops
361+ elif self . tool == " test-backend-ops" :
360362 return self ._get_rows_test_backend_ops (properties , hexsha8_baseline , hexsha8_compare )
363+ else :
364+ assert False
361365
362366 def _get_rows_llama_bench (self , properties : list [str ], hexsha8_baseline : str , hexsha8_compare : str ) -> Sequence [tuple ]:
363367 select_string = ", " .join (
@@ -1041,8 +1045,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
10411045 # Determine y-axis label based on tool type
10421046 if tool_type == "llama-bench" :
10431047 y_label = "Tokens per second (t/s)"
1044- else : # test-backend-ops
1048+ elif tool_type == " test-backend-ops" :
10451049 y_label = metric_name
1050+ else :
1051+ assert False
10461052
10471053 ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
10481054 ax .set_ylabel (y_label , fontsize = 12 , fontweight = 'bold' )
0 commit comments