Commit 7697ad5

allow for multiple calls to the benchmark fixture within a test
1 parent 1a53bed commit 7697ad5

File tree: 3 files changed, 70 additions and 74 deletions

codeflash/benchmarking/plugin/plugin.py
codeflash/benchmarking/utils.py
codeflash/discovery/functions_to_optimize.py

codeflash/benchmarking/plugin/plugin.py

Lines changed: 37 additions & 55 deletions

@@ -5,10 +5,12 @@
 import sys
 import time
 from pathlib import Path
+from typing import Any, Callable
 
 import pytest
 
 from codeflash.benchmarking.codeflash_trace import codeflash_trace
+from codeflash.cli_cmds.cli import logger
 from codeflash.code_utils.code_utils import module_name_from_file_path
 from codeflash.models.models import BenchmarkKey
 
@@ -22,7 +24,6 @@ def __init__(self) -> None:
 
     def setup(self, trace_path: str, project_root: str) -> None:
         try:
-            # Open connection
             self.project_root = project_root
             self._trace_path = trace_path
             self._connection = sqlite3.connect(self._trace_path)
@@ -35,12 +36,10 @@ def setup(self, trace_path: str, project_root: str) -> None:
                 "benchmark_time_ns INTEGER)"
             )
             self._connection.commit()
-            self.close()  # Reopen only at the end of pytest session
+            self.close()
         except Exception as e:
-            print(f"Database setup error: {e}")
-            if self._connection:
-                self._connection.close()
-            self._connection = None
+            logger.error(f"Database setup error: {e}")
+            self.close()
             raise
 
     def write_benchmark_timings(self) -> None:
@@ -52,15 +51,14 @@ def write_benchmark_timings(self) -> None:
 
         try:
             cur = self._connection.cursor()
-            # Insert data into the benchmark_timings table
             cur.executemany(
                 "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
                 self.benchmark_timings,
             )
             self._connection.commit()
-            self.benchmark_timings = []  # Clear the benchmark timings list
+            self.benchmark_timings.clear()
         except Exception as e:
-            print(f"Error writing to benchmark timings database: {e}")
+            logger.error(f"Error writing to benchmark timings database: {e}")
            self._connection.rollback()
            raise
 
@@ -83,22 +81,18 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
         - Values are function timing in milliseconds
 
         """
-        # Initialize the result dictionary
         result = {}
 
-        # Connect to the SQLite database
         connection = sqlite3.connect(trace_path)
         cursor = connection.cursor()
 
         try:
-            # Query the function_calls table for all function calls
             cursor.execute(
                 "SELECT module_name, class_name, function_name, "
                 "benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns "
                 "FROM benchmark_function_timings"
             )
 
-            # Process each row
             for row in cursor.fetchall():
                 module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row
 
@@ -110,7 +104,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
 
                 # Create the benchmark key (file::function::line)
                 benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
-                # Initialize the inner dictionary if needed
                 if qualified_name not in result:
                     result[qualified_name] = {}
 
@@ -122,7 +115,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
                 result[qualified_name][benchmark_key] = time_ns
 
         finally:
-            # Close the connection
             connection.close()
 
         return result
@@ -140,11 +132,9 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
         - Values are total benchmark timing in milliseconds (with overhead subtracted)
 
         """
-        # Initialize the result dictionary
         result = {}
         overhead_by_benchmark = {}
 
-        # Connect to the SQLite database
         connection = sqlite3.connect(trace_path)
         cursor = connection.cursor()
 
@@ -156,7 +146,6 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
                 "GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number"
             )
 
-            # Process overhead information
             for row in cursor.fetchall():
                 benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
                 benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
@@ -168,52 +157,48 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
                 "FROM benchmark_timings"
             )
 
-            # Process each row and subtract overhead
             for row in cursor.fetchall():
                 benchmark_file, benchmark_func, benchmark_line, time_ns = row
 
-                # Create the benchmark key (file::function::line)
-                benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
+                benchmark_key = BenchmarkKey(
+                    module_path=benchmark_file, function_name=benchmark_func
+                )  # (file::function::line)
                 # Subtract overhead from total time
                 overhead = overhead_by_benchmark.get(benchmark_key, 0)
                 result[benchmark_key] = time_ns - overhead
 
         finally:
-            # Close the connection
             connection.close()
 
         return result
 
-    # Pytest hooks
     @pytest.hookimpl
-    def pytest_sessionfinish(self, session, exitstatus):
+    def pytest_sessionfinish(self, session: pytest.Session, exitstatus: int) -> None:  # noqa: ARG002
         """Execute after whole test run is completed."""
-        # Write any remaining benchmark timings to the database
         codeflash_trace.close()
         if self.benchmark_timings:
             self.write_benchmark_timings()
-        # Close the database connection
         self.close()
 
     @staticmethod
-    def pytest_addoption(parser):
+    def pytest_addoption(parser: pytest.Parser) -> None:
         parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
 
     @staticmethod
-    def pytest_plugin_registered(plugin, manager):
+    def pytest_plugin_registered(plugin: Any, manager: Any) -> None:  # noqa: ANN401
         # Not necessary since run with -p no:benchmark, but just in case
         if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
             manager.unregister(plugin)
 
     @staticmethod
-    def pytest_configure(config):
+    def pytest_configure(config: pytest.Config) -> None:
         """Register the benchmark marker."""
         config.addinivalue_line(
             "markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
         )
 
     @staticmethod
-    def pytest_collection_modifyitems(config, items):
+    def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
         # Skip tests that don't have the benchmark fixture
         if not config.getoption("--codeflash-trace"):
             return
@@ -236,54 +221,51 @@ def pytest_collection_modifyitems(config, items):
 
     # Benchmark fixture
     class Benchmark:
-        def __init__(self, request):
-            self.request = request
-
-        def __call__(self, func, *args, **kwargs):
-            """Handle both direct function calls and decorator usage."""
-            if args or kwargs:
-                # Used as benchmark(func, *args, **kwargs)
-                return self._run_benchmark(func, *args, **kwargs)
+        """Benchmark fixture class for running and timing benchmarked functions."""
 
-            # Used as @benchmark decorator
-            def wrapped_func(*inner_args, **inner_kwargs):
-                return self._run_benchmark(func, *inner_args, **inner_kwargs)
+        def __init__(self, request: pytest.FixtureRequest) -> None:
+            self.request = request
+            self._call_count = 0
 
-            return wrapped_func
+        def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+            benchmark_name_suffix = kwargs.pop("benchmark_name_suffix", None)
+            return self._run_benchmark(func, args, kwargs, benchmark_name_suffix)
 
-        def _run_benchmark(self, func, *args, **kwargs):
-            """Actual benchmark implementation."""
+        def _run_benchmark(
+            self, func: Callable, args: tuple, kwargs: dict, benchmark_name_suffix: str | None = None
+        ) -> Any:  # noqa: ANN401
             benchmark_module_path = module_name_from_file_path(
                 Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
             )
             benchmark_function_name = self.request.node.name
             line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
-            # Set env vars
-            os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
+            self._call_count += 1
+            if benchmark_name_suffix:
+                call_identifier = f"{benchmark_function_name}::{benchmark_name_suffix}"
+            else:
+                call_identifier = f"{benchmark_function_name}::call_{self._call_count}"
+
+            os.environ["CODEFLASH_BENCHMARKING"] = "True"
+            os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = call_identifier
             os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
             os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
             os.environ["CODEFLASH_BENCHMARKING"] = "True"
-            # Run the function
-            start = time.time_ns()
+            start = time.perf_counter_ns()
             result = func(*args, **kwargs)
-            end = time.time_ns()
-            # Reset the environment variable
+            end = time.perf_counter_ns()
             os.environ["CODEFLASH_BENCHMARKING"] = "False"
 
-            # Write function calls
             codeflash_trace.write_function_timings()
-            # Reset function call count
             codeflash_trace.function_call_count = 0
-            # Add to the benchmark timings buffer
             codeflash_benchmark_plugin.benchmark_timings.append(
-                (benchmark_module_path, benchmark_function_name, line_number, end - start)
+                (benchmark_module_path, call_identifier, line_number, end - start)
            )
 
            return result
 
    @staticmethod
    @pytest.fixture
-    def benchmark(request):
+    def benchmark(request: pytest.FixtureRequest) -> CodeFlashBenchmarkPlugin.Benchmark | None:
        if not request.config.getoption("--codeflash-trace"):
            return None
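
The fixture change above is what enables the commit's headline behavior: each call is now recorded under its own identifier, either an automatic ::call_N counter or a caller-supplied benchmark_name_suffix keyword. A minimal usage sketch, assuming a run with --codeflash-trace so the fixture is active; the test and the function under test are hypothetical:

import pytest


def parse_csv_line(line: str) -> list[str]:  # hypothetical function under test
    return [field.strip() for field in line.split(",")]


@pytest.mark.benchmark
def test_parse_csv_line(benchmark):
    # Recorded as "test_parse_csv_line::call_1"
    assert benchmark(parse_csv_line, "a,b,c") == ["a", "b", "c"]
    # Recorded as "test_parse_csv_line::padded" via the new keyword
    assert benchmark(parse_csv_line, " a , b , c ", benchmark_name_suffix="padded") == ["a", "b", "c"]

Note that the decorator-style usage handled by the old __call__ is dropped here; after this commit the fixture is always invoked directly with the callable and its arguments.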

codeflash/benchmarking/utils.py

Lines changed: 32 additions & 19 deletions

@@ -1,12 +1,10 @@
 from __future__ import annotations
 
-import shutil
 from typing import TYPE_CHECKING, Optional
 
-from rich.console import Console
 from rich.table import Table
 
-from codeflash.cli_cmds.console import logger
+from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.time_utils import humanize_runtime
 from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo
 from codeflash.result.critic import performance_gain
@@ -42,36 +40,51 @@ def validate_and_format_benchmark_table(
 
 
 def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None:
-    try:
-        terminal_width = int(shutil.get_terminal_size().columns * 0.9)
-    except Exception:
-        terminal_width = 120  # Fallback width
-    console = Console(width=terminal_width)
     for func_path, sorted_tests in function_to_results.items():
         console.print()
         function_name = func_path.split(":")[-1]
 
-        # Create a table for this function
-        table = Table(title=f"Function: {function_name}", width=terminal_width, border_style="blue", show_lines=True)
-        benchmark_col_width = max(int(terminal_width * 0.4), 40)
-        # Add columns - split the benchmark test into two columns
-        table.add_column("Benchmark Module Path", width=benchmark_col_width, style="cyan", overflow="fold")
+        table = Table(title=f"Function: {function_name}", border_style="blue", show_lines=True)
+        table.add_column("Benchmark Module Path", style="cyan", overflow="fold")
         table.add_column("Test Function", style="magenta", overflow="fold")
         table.add_column("Total Time (ms)", justify="right", style="green")
         table.add_column("Function Time (ms)", justify="right", style="yellow")
         table.add_column("Percentage (%)", justify="right", style="red")
 
-        for benchmark_key, total_time, func_time, percentage in sorted_tests:
-            # Split the benchmark test into module path and function name
-            module_path = benchmark_key.module_path
+        multi_call_bases = set()
+        call_1_tests = []
+
+        for i, (benchmark_key, _, _, _) in enumerate(sorted_tests):
             test_function = benchmark_key.function_name
+            module_path = benchmark_key.module_path
+            if "::call_" in test_function:
+                try:
+                    base_name, call_part = test_function.rsplit("::call_", 1)
+                    call_num = int(call_part)
+                    if call_num == 1:
+                        call_1_tests.append((i, base_name, module_path))
+                    elif call_num > 1:
+                        multi_call_bases.add((base_name, module_path))
+                except ValueError:
+                    pass
+
+        tests_to_modify = {
+            index: base_name
+            for index, base_name, module_path in call_1_tests
+            if (base_name, module_path) not in multi_call_bases
+        }
+
+        for i, (benchmark_key, total_time, func_time, percentage) in enumerate(sorted_tests):
+            module_path = benchmark_key.module_path
+            test_function_display = tests_to_modify.get(i, benchmark_key.function_name)
 
             if total_time == 0.0:
-                table.add_row(module_path, test_function, "N/A", "N/A", "N/A")
+                table.add_row(module_path, test_function_display, "N/A", "N/A", "N/A")
             else:
-                table.add_row(module_path, test_function, f"{total_time:.3f}", f"{func_time:.3f}", f"{percentage:.2f}")
+                table.add_row(
+                    module_path, test_function_display, f"{total_time:.3f}", f"{func_time:.3f}", f"{percentage:.2f}"
                )
 
-        # Print the table
        console.print(table)

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 0 deletions

@@ -201,6 +201,7 @@ def get_functions_to_optimize(
         functions, test_cfg.tests_root, ignore_paths, project_root, module_root
     )
     logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
+    console.rule()
     return filtered_modified_functions, functions_count
