Skip to content

Commit 9a41bdd

Browse files
committed
shifted benchmark class in plugin, improved display of benchmark info
1 parent 0c2a3b6 commit 9a41bdd

File tree

2 files changed

+49
-32
lines changed

2 files changed

+49
-32
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,29 @@
55
import os
66
class CodeFlashBenchmarkPlugin:
77
benchmark_timings = []
8+
9+
class Benchmark:
10+
def __init__(self, request):
11+
self.request = request
12+
13+
def __call__(self, func, *args, **kwargs):
14+
benchmark_file_name = self.request.node.fspath.basename
15+
benchmark_function_name = self.request.node.name
16+
line_number = str(sys._getframe(1).f_lineno) # 1 frame up in the call stack
17+
18+
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
19+
os.environ["CODEFLASH_BENCHMARK_FILE_NAME"] = benchmark_file_name
20+
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = line_number
21+
os.environ["CODEFLASH_BENCHMARKING"] = "True"
22+
23+
start = time.perf_counter_ns()
24+
result = func(*args, **kwargs)
25+
end = time.perf_counter_ns()
26+
27+
os.environ["CODEFLASH_BENCHMARKING"] = "False"
28+
CodeFlashBenchmarkPlugin.benchmark_timings.append(
29+
(benchmark_file_name, benchmark_function_name, line_number, end - start))
30+
return result
831
@staticmethod
932
def pytest_addoption(parser):
1033
parser.addoption(
@@ -36,23 +59,4 @@ def benchmark(request):
3659
if not request.config.getoption("--codeflash-trace"):
3760
return None
3861

39-
class Benchmark:
40-
41-
def __call__(self, func, *args, **kwargs):
42-
benchmark_file_name = request.node.fspath.basename
43-
benchmark_function_name = request.node.name
44-
line_number = str(sys._getframe(1).f_lineno) # 1 frame up in the call stack
45-
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
46-
os.environ["CODEFLASH_BENCHMARK_FILE_NAME"] = benchmark_file_name
47-
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = line_number
48-
os.environ["CODEFLASH_BENCHMARKING"] = "True"
49-
50-
start = time.perf_counter_ns()
51-
result = func(*args, **kwargs)
52-
end = time.perf_counter_ns()
53-
54-
os.environ["CODEFLASH_BENCHMARKING"] = "False"
55-
CodeFlashBenchmarkPlugin.benchmark_timings.append((benchmark_file_name, benchmark_function_name, line_number, end - start))
56-
return result
57-
58-
return Benchmark()
62+
return CodeFlashBenchmarkPlugin.Benchmark(request)

codeflash/benchmarking/utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1-
def print_benchmark_table(function_benchmark_timings: dict[str,dict[str,int]], total_benchmark_timings: dict[str,int]):
2-
# Define column widths
3-
benchmark_col_width = 50
4-
time_col_width = 15
1+
from rich.console import Console
2+
from rich.table import Table
53

6-
# Print table header
7-
header = f"{'Benchmark Test':{benchmark_col_width}} | {'Total Time (ms)':{time_col_width}} | {'Function Time (ms)':{time_col_width}} | {'Percentage (%)':{time_col_width}}"
8-
print(header)
9-
print("-" * len(header))
4+
5+
def print_benchmark_table(function_benchmark_timings: dict[str, dict[str, int]],
6+
total_benchmark_timings: dict[str, int]):
7+
console = Console()
108

119
# Process each function's benchmark data
1210
for func_path, test_times in function_benchmark_timings.items():
1311
function_name = func_path.split(":")[-1]
14-
print(f"\n== Function: {function_name} ==")
12+
13+
# Create a table for this function
14+
table = Table(title=f"Function: {function_name}", border_style="blue")
15+
16+
# Add columns
17+
table.add_column("Benchmark Test", style="cyan", no_wrap=True)
18+
table.add_column("Total Time (ms)", justify="right", style="green")
19+
table.add_column("Function Time (ms)", justify="right", style="yellow")
20+
table.add_column("Percentage (%)", justify="right", style="red")
1521

1622
# Sort by percentage (highest first)
1723
sorted_tests = []
@@ -26,9 +32,16 @@ def print_benchmark_table(function_benchmark_timings: dict[str,dict[str,int]], t
2632

2733
sorted_tests.sort(key=lambda x: x[3], reverse=True)
2834

29-
# Print each test's data
35+
# Add rows to the table
3036
for test_name, total_time, func_time, percentage in sorted_tests:
3137
benchmark_file, benchmark_func, benchmark_line = test_name.split("::")
3238
benchmark_name = f"{benchmark_file}::{benchmark_func}"
33-
print(f"{benchmark_name:{benchmark_col_width}} | {total_time:{time_col_width}.3f} | {func_time:{time_col_width}.3f} | {percentage:{time_col_width}.2f}")
34-
print()
39+
table.add_row(
40+
benchmark_name,
41+
f"{total_time:.3f}",
42+
f"{func_time:.3f}",
43+
f"{percentage:.2f}"
44+
)
45+
46+
# Print the table
47+
console.print(table)

0 commit comments

Comments
 (0)