55import sys
66import time
77from pathlib import Path
8+ from typing import Any , Callable
89
910import pytest
1011
1112from codeflash .benchmarking .codeflash_trace import codeflash_trace
13+ from codeflash .cli_cmds .cli import logger
1214from codeflash .code_utils .code_utils import module_name_from_file_path
1315from codeflash .models .models import BenchmarkKey
1416
@@ -22,7 +24,6 @@ def __init__(self) -> None:
2224
2325 def setup (self , trace_path : str , project_root : str ) -> None :
2426 try :
25- # Open connection
2627 self .project_root = project_root
2728 self ._trace_path = trace_path
2829 self ._connection = sqlite3 .connect (self ._trace_path )
@@ -35,12 +36,10 @@ def setup(self, trace_path: str, project_root: str) -> None:
3536 "benchmark_time_ns INTEGER)"
3637 )
3738 self ._connection .commit ()
38- self .close () # Reopen only at the end of pytest session
39+ self .close ()
3940 except Exception as e :
40- print (f"Database setup error: { e } " )
41- if self ._connection :
42- self ._connection .close ()
43- self ._connection = None
41+ logger .error (f"Database setup error: { e } " )
42+ self .close ()
4443 raise
4544
4645 def write_benchmark_timings (self ) -> None :
@@ -52,15 +51,14 @@ def write_benchmark_timings(self) -> None:
5251
5352 try :
5453 cur = self ._connection .cursor ()
55- # Insert data into the benchmark_timings table
5654 cur .executemany (
5755 "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)" ,
5856 self .benchmark_timings ,
5957 )
6058 self ._connection .commit ()
61- self .benchmark_timings = [] # Clear the benchmark timings list
59+ self .benchmark_timings . clear ()
6260 except Exception as e :
63- print (f"Error writing to benchmark timings database: { e } " )
61+ logger . error (f"Error writing to benchmark timings database: { e } " )
6462 self ._connection .rollback ()
6563 raise
6664
@@ -83,22 +81,18 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
8381 - Values are function timing in milliseconds
8482
8583 """
86- # Initialize the result dictionary
8784 result = {}
8885
89- # Connect to the SQLite database
9086 connection = sqlite3 .connect (trace_path )
9187 cursor = connection .cursor ()
9288
9389 try :
94- # Query the function_calls table for all function calls
9590 cursor .execute (
9691 "SELECT module_name, class_name, function_name, "
9792 "benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns "
9893 "FROM benchmark_function_timings"
9994 )
10095
101- # Process each row
10296 for row in cursor .fetchall ():
10397 module_name , class_name , function_name , benchmark_file , benchmark_func , benchmark_line , time_ns = row
10498
@@ -110,7 +104,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
110104
111105 # Create the benchmark key (file::function::line)
112106 benchmark_key = BenchmarkKey (module_path = benchmark_file , function_name = benchmark_func )
113- # Initialize the inner dictionary if needed
114107 if qualified_name not in result :
115108 result [qualified_name ] = {}
116109
@@ -122,7 +115,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
122115 result [qualified_name ][benchmark_key ] = time_ns
123116
124117 finally :
125- # Close the connection
126118 connection .close ()
127119
128120 return result
@@ -140,11 +132,9 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
140132 - Values are total benchmark timing in milliseconds (with overhead subtracted)
141133
142134 """
143- # Initialize the result dictionary
144135 result = {}
145136 overhead_by_benchmark = {}
146137
147- # Connect to the SQLite database
148138 connection = sqlite3 .connect (trace_path )
149139 cursor = connection .cursor ()
150140
@@ -156,7 +146,6 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
156146 "GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number"
157147 )
158148
159- # Process overhead information
160149 for row in cursor .fetchall ():
161150 benchmark_file , benchmark_func , benchmark_line , total_overhead_ns = row
162151 benchmark_key = BenchmarkKey (module_path = benchmark_file , function_name = benchmark_func )
@@ -168,52 +157,48 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
168157 "FROM benchmark_timings"
169158 )
170159
171- # Process each row and subtract overhead
172160 for row in cursor .fetchall ():
173161 benchmark_file , benchmark_func , benchmark_line , time_ns = row
174162
175- # Create the benchmark key (file::function::line)
176- benchmark_key = BenchmarkKey (module_path = benchmark_file , function_name = benchmark_func )
163+ benchmark_key = BenchmarkKey (
164+ module_path = benchmark_file , function_name = benchmark_func
165+ ) # (file::function::line)
177166 # Subtract overhead from total time
178167 overhead = overhead_by_benchmark .get (benchmark_key , 0 )
179168 result [benchmark_key ] = time_ns - overhead
180169
181170 finally :
182- # Close the connection
183171 connection .close ()
184172
185173 return result
186174
187- # Pytest hooks
@pytest.hookimpl
def pytest_sessionfinish(self, session: "pytest.Session", exitstatus: int) -> None:  # noqa: ARG002
    """Run once after the whole test session has finished.

    Finalizes the per-function trace first, then flushes any buffered
    benchmark timings and closes this plugin's database connection.
    """
    codeflash_trace.close()
    pending = self.benchmark_timings
    if pending:
        self.write_benchmark_timings()
    self.close()
197182
@staticmethod
def pytest_addoption(parser: "pytest.Parser") -> None:
    """Register the ``--codeflash-trace`` command-line flag (off by default)."""
    parser.addoption(
        "--codeflash-trace",
        action="store_true",
        default=False,
        help="Enable CodeFlash tracing",
    )
201186
@staticmethod
def pytest_plugin_registered(plugin: Any, manager: Any) -> None:  # noqa: ANN401
    """Unregister pytest-benchmark if it is present.

    Not strictly necessary since we run with ``-p no:benchmark``,
    but kept as a safety net.
    """
    if getattr(plugin, "name", None) == "pytest-benchmark":
        manager.unregister(plugin)
207192
@staticmethod
def pytest_configure(config: "pytest.Config") -> None:
    """Register the custom ``benchmark`` marker with pytest."""
    marker = "benchmark: mark test as a benchmark that should be run with codeflash tracing"
    config.addinivalue_line("markers", marker)
214199
215200 @staticmethod
216- def pytest_collection_modifyitems (config , items ) :
201+ def pytest_collection_modifyitems (config : pytest . Config , items : list [ pytest . Item ]) -> None :
217202 # Skip tests that don't have the benchmark fixture
218203 if not config .getoption ("--codeflash-trace" ):
219204 return
@@ -236,54 +221,51 @@ def pytest_collection_modifyitems(config, items):
236221
237222 # Benchmark fixture
238223 class Benchmark :
239- def __init__ (self , request ):
240- self .request = request
241-
242- def __call__ (self , func , * args , ** kwargs ):
243- """Handle both direct function calls and decorator usage."""
244- if args or kwargs :
245- # Used as benchmark(func, *args, **kwargs)
246- return self ._run_benchmark (func , * args , ** kwargs )
224+ """Benchmark fixture class for running and timing benchmarked functions."""
247225
248- # Used as @benchmark decorator
249- def wrapped_func ( * inner_args , ** inner_kwargs ):
250- return self ._run_benchmark ( func , * inner_args , ** inner_kwargs )
226+ def __init__ ( self , request : pytest . FixtureRequest ) -> None :
227+ self . request = request
228+ self ._call_count = 0
251229
252- return wrapped_func
230+ def __call__ (self , func : Callable [..., Any ], * args : Any , ** kwargs : Any ) -> Any : # noqa: ANN401
231+ benchmark_name_suffix = kwargs .pop ("benchmark_name_suffix" , None )
232+ return self ._run_benchmark (func , args , kwargs , benchmark_name_suffix )
253233
254- def _run_benchmark (self , func , * args , ** kwargs ):
255- """Actual benchmark implementation."""
234+ def _run_benchmark (
235+ self , func : Callable , args : tuple , kwargs : dict , benchmark_name_suffix : str | None = None
236+ ) -> Any : # noqa: ANN401
256237 benchmark_module_path = module_name_from_file_path (
257238 Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root )
258239 )
259240 benchmark_function_name = self .request .node .name
260241 line_number = int (str (sys ._getframe (2 ).f_lineno )) # 2 frames up in the call stack
261- # Set env vars
262- os .environ ["CODEFLASH_BENCHMARK_FUNCTION_NAME" ] = benchmark_function_name
242+ self ._call_count += 1
243+ if benchmark_name_suffix :
244+ call_identifier = f"{ benchmark_function_name } ::{ benchmark_name_suffix } "
245+ else :
246+ call_identifier = f"{ benchmark_function_name } ::call_{ self ._call_count } "
247+
248+ os .environ ["CODEFLASH_BENCHMARKING" ] = "True"
249+ os .environ ["CODEFLASH_BENCHMARK_FUNCTION_NAME" ] = call_identifier
263250 os .environ ["CODEFLASH_BENCHMARK_MODULE_PATH" ] = benchmark_module_path
264251 os .environ ["CODEFLASH_BENCHMARK_LINE_NUMBER" ] = str (line_number )
265252 os .environ ["CODEFLASH_BENCHMARKING" ] = "True"
266- # Run the function
267- start = time .time_ns ()
253+ start = time .perf_counter_ns ()
268254 result = func (* args , ** kwargs )
269- end = time .time_ns ()
270- # Reset the environment variable
255+ end = time .perf_counter_ns ()
271256 os .environ ["CODEFLASH_BENCHMARKING" ] = "False"
272257
273- # Write function calls
274258 codeflash_trace .write_function_timings ()
275- # Reset function call count
276259 codeflash_trace .function_call_count = 0
277- # Add to the benchmark timings buffer
278260 codeflash_benchmark_plugin .benchmark_timings .append (
279- (benchmark_module_path , benchmark_function_name , line_number , end - start )
261+ (benchmark_module_path , call_identifier , line_number , end - start )
280262 )
281263
282264 return result
283265
284266 @staticmethod
285267 @pytest .fixture
286- def benchmark (request ) :
268+ def benchmark (request : pytest . FixtureRequest ) -> CodeFlashBenchmarkPlugin . Benchmark | None :
287269 if not request .config .getoption ("--codeflash-trace" ):
288270 return None
289271
0 commit comments