Skip to content

Commit 3915a8d

Browse files
committed
fix tests
1 parent deeaecf commit 3915a8d

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

scripts/compare-llama-bench.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100
2020
raise e
2121

22+
2223
logger = logging.getLogger("compare-llama-bench")
2324

2425
# All llama-bench SQL fields
@@ -129,14 +130,6 @@
129130

130131
logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
131132

132-
if known_args.plot:
133-
try:
134-
import matplotlib.pyplot as plt
135-
import matplotlib
136-
matplotlib.use('Agg')
137-
except ImportError as e:
138-
logger.error("matplotlib is required for --plot.")
139-
raise e
140133

141134
if known_args.check:
142135
# Check if all required Python libraries are installed. Would have failed earlier if not.
@@ -620,6 +613,13 @@ def valid_format(data_files: list[str]) -> bool:
620613

621614
if known_args.plot:
622615
def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str):
616+
try:
617+
import matplotlib.pyplot as plt
618+
import matplotlib
619+
matplotlib.use('Agg')
620+
except ImportError as e:
621+
logger.error("matplotlib is required for --plot.")
622+
raise e
623623

624624
data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
625625
plot_x_index = None
@@ -643,6 +643,9 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
643643
group_key_parts = []
644644
test_name = row[-4]
645645

646+
base_test = ""
647+
x_value = None
648+
646649
if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
647650
for j, val in enumerate(row[:-4]):
648651
header_name = data_headers[j]

0 commit comments

Comments
 (0)