diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py
index 313817041..994600128 100644
--- a/codeflash/benchmarking/plugin/plugin.py
+++ b/codeflash/benchmarking/plugin/plugin.py
@@ -20,7 +20,7 @@ def __init__(self) -> None:
         self.project_root = None
         self.benchmark_timings = []
 
-    def setup(self, trace_path:str, project_root:str) -> None:
+    def setup(self, trace_path: str, project_root: str) -> None:
         try:
             # Open connection
             self.project_root = project_root
@@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
                 "benchmark_time_ns INTEGER)"
             )
             self._connection.commit()
-            self.close() # Reopen only at the end of pytest session
+            self.close()  # Reopen only at the end of pytest session
         except Exception as e:
             print(f"Database setup error: {e}")
             if self._connection:
@@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None:
             # Insert data into the benchmark_timings table
             cur.executemany(
                 "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
-                self.benchmark_timings
+                self.benchmark_timings,
             )
             self._connection.commit()
-            self.benchmark_timings = [] # Clear the benchmark timings list
+            self.benchmark_timings = []  # Clear the benchmark timings list
         except Exception as e:
             print(f"Error writing to benchmark timings database: {e}")
             self._connection.rollback()
             raise
+
     def close(self) -> None:
         if self._connection:
             self._connection.close()
@@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):
 
     @staticmethod
     def pytest_addoption(parser):
-        parser.addoption(
-            "--codeflash-trace",
-            action="store_true",
-            default=False,
-            help="Enable CodeFlash tracing"
-        )
+        parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
 
     @staticmethod
     def pytest_plugin_registered(plugin, manager):
@@ -213,9 +209,9 @@ def pytest_plugin_registered(plugin, manager):
     def pytest_configure(config):
         """Register the benchmark marker."""
         config.addinivalue_line(
-            "markers",
-            "benchmark: mark test as a benchmark that should be run with codeflash tracing"
+            "markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
         )
+
     @staticmethod
     def pytest_collection_modifyitems(config, items):
         # Skip tests that don't have the benchmark fixture
@@ -224,19 +220,18 @@ def pytest_collection_modifyitems(config, items):
 
         skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
         for item in items:
-            # Check for direct benchmark fixture usage
-            has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
+            if hasattr(item, "fixturenames"):
+                if "benchmark" in item.fixturenames:
+                    continue  # Skip current iteration if benchmark fixture is present
 
             # Check for @pytest.mark.benchmark marker
-            has_marker = False
             if hasattr(item, "get_closest_marker"):
                 marker = item.get_closest_marker("benchmark")
                 if marker is not None:
-                    has_marker = True
+                    continue  # Skip current iteration if benchmark marker is present
 
             # Skip if neither fixture nor marker is present
-            if not (has_fixture or has_marker):
-                item.add_marker(skip_no_benchmark)
+            item.add_marker(skip_no_benchmark)
 
     # Benchmark fixture
     class Benchmark:
@@ -248,16 +243,19 @@ def __call__(self, func, *args, **kwargs):
             if args or kwargs:
                 # Used as benchmark(func, *args, **kwargs)
                 return self._run_benchmark(func, *args, **kwargs)
+
             # Used as @benchmark decorator
             def wrapped_func(*args, **kwargs):
                 return func(*args, **kwargs)
+
             result = self._run_benchmark(func)
             return wrapped_func
 
        def _run_benchmark(self, func, *args, **kwargs):
            """Actual benchmark implementation."""
-            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
-                                                               Path(codeflash_benchmark_plugin.project_root))
+            benchmark_module_path = module_name_from_file_path(
+                Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
+            )
             benchmark_function_name = self.request.node.name
             line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
             # Set env vars
@@ -278,7 +276,8 @@ def _run_benchmark(self, func, *args, **kwargs):
             codeflash_trace.function_call_count = 0
             # Add to the benchmark timings buffer
             codeflash_benchmark_plugin.benchmark_timings.append(
-                (benchmark_module_path, benchmark_function_name, line_number, end - start))
+                (benchmark_module_path, benchmark_function_name, line_number, end - start)
+            )
 
             return result
 
@@ -290,4 +289,5 @@ def benchmark(request):
     return CodeFlashBenchmarkPlugin.Benchmark(request)
 
+
 codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
 
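For context, a minimal sketch of a test module this plugin would collect and time, showing both invocation styles the benchmark fixture supports per the diff above; the module name, test names, and the sorter helper are hypothetical:

    # test_sorter_benchmarks.py (hypothetical) -- run with: pytest --codeflash-trace
    import pytest


    def sorter(data):
        # Hypothetical function under benchmark
        return sorted(data)


    def test_sorter_call_style(benchmark):
        # Call style: benchmark(func, *args, **kwargs) runs and times the call immediately
        result = benchmark(sorter, list(range(1000, 0, -1)))
        assert result[0] == 1


    @pytest.mark.benchmark
    def test_sorter_decorator_style(benchmark):
        # Decorator style: the zero-argument callable is timed once at decoration time
        @benchmark
        def run():
            return sorter(list(range(1000, 0, -1)))

Tests that request neither the benchmark fixture nor the @pytest.mark.benchmark marker get skip_no_benchmark applied during collection by pytest_collection_modifyitems.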