diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index c7c11c6d4..8e312c41f 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -20,7 +20,7 @@ def __init__(self) -> None: self.project_root = None self.benchmark_timings = [] - def setup(self, trace_path:str, project_root:str) -> None: + def setup(self, trace_path: str, project_root: str) -> None: try: # Open connection self.project_root = project_root @@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None: "benchmark_time_ns INTEGER)" ) self._connection.commit() - self.close() # Reopen only at the end of pytest session + self.close() # Reopen only at the end of pytest session except Exception as e: print(f"Database setup error: {e}") if self._connection: @@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None: # Insert data into the benchmark_timings table cur.executemany( "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)", - self.benchmark_timings + self.benchmark_timings, ) self._connection.commit() - self.benchmark_timings = [] # Clear the benchmark timings list + self.benchmark_timings = [] # Clear the benchmark timings list except Exception as e: print(f"Error writing to benchmark timings database: {e}") self._connection.rollback() raise + def close(self) -> None: if self._connection: self._connection.close() @@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus): @staticmethod def pytest_addoption(parser): - parser.addoption( - "--codeflash-trace", - action="store_true", - default=False, - help="Enable CodeFlash tracing" - ) + parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing") @staticmethod def pytest_plugin_registered(plugin, manager): @@ -216,11 +212,21 @@ def pytest_collection_modifyitems(config, items): return skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture") + + items_with_benchmark = [] + items_without_benchmark = [] + for item in items: if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames: - continue + items_with_benchmark.append(item) + else: + items_without_benchmark.append(item) + + for item in items_without_benchmark: item.add_marker(skip_no_benchmark) + items[:] = items_with_benchmark + items_without_benchmark + # Benchmark fixture class Benchmark: def __init__(self, request): @@ -244,7 +250,9 @@ def test_something(benchmark): a """ - benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)) + benchmark_module_path = module_name_from_file_path( + Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root) + ) benchmark_function_name = self.request.node.name line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack @@ -254,7 +262,7 @@ def test_something(benchmark): os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) os.environ["CODEFLASH_BENCHMARKING"] = "True" - # Run the function + # Run the function start = time.perf_counter_ns() result = func(*args, **kwargs) end = time.perf_counter_ns() @@ -268,7 +276,8 @@ def test_something(benchmark): codeflash_trace.function_call_count = 0 # Add to the benchmark timings buffer codeflash_benchmark_plugin.benchmark_timings.append( - (benchmark_module_path, benchmark_function_name, line_number, end - start)) + (benchmark_module_path, benchmark_function_name, line_number, end - start) + ) return result @@ -280,4 +289,5 @@ def benchmark(request): return CodeFlashBenchmarkPlugin.Benchmark(request) + codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()