diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index c7c11c6d4..4d5e7261d 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -20,7 +20,7 @@ def __init__(self) -> None: self.project_root = None self.benchmark_timings = [] - def setup(self, trace_path:str, project_root:str) -> None: + def setup(self, trace_path: str, project_root: str) -> None: try: # Open connection self.project_root = project_root @@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None: "benchmark_time_ns INTEGER)" ) self._connection.commit() - self.close() # Reopen only at the end of pytest session + self.close() # Reopen only at the end of pytest session except Exception as e: print(f"Database setup error: {e}") if self._connection: @@ -49,20 +49,23 @@ def write_benchmark_timings(self) -> None: if self._connection is None: self._connection = sqlite3.connect(self._trace_path) + self._connection.execute("PRAGMA journal_mode = WAL") + self._connection.execute("PRAGMA synchronous = NORMAL") try: cur = self._connection.cursor() - # Insert data into the benchmark_timings table - cur.executemany( - "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)", - self.benchmark_timings - ) + # Prepare SQL statement only once + insert_query = "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)" + + # Use `executemany` to insert data into the benchmark_timings table + cur.executemany(insert_query, self.benchmark_timings) self._connection.commit() - self.benchmark_timings = [] # Clear the benchmark timings list - except Exception as e: + self.benchmark_timings.clear() # Clear the benchmark timings list using clear() for slight efficiency gain + except sqlite3.Error as e: print(f"Error writing to benchmark timings database: {e}") self._connection.rollback() raise + def close(self) -> None: if self._connection: self._connection.close() @@ -196,12 +199,7 @@ def pytest_sessionfinish(self, session, exitstatus): @staticmethod def pytest_addoption(parser): - parser.addoption( - "--codeflash-trace", - action="store_true", - default=False, - help="Enable CodeFlash tracing" - ) + parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing") @staticmethod def pytest_plugin_registered(plugin, manager): @@ -244,7 +242,9 @@ def test_something(benchmark): a """ - benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)) + benchmark_module_path = module_name_from_file_path( + Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root) + ) benchmark_function_name = self.request.node.name line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack @@ -254,7 +254,7 @@ def test_something(benchmark): os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) os.environ["CODEFLASH_BENCHMARKING"] = "True" - # Run the function + # Run the function start = time.perf_counter_ns() result = func(*args, **kwargs) end = time.perf_counter_ns() @@ -268,7 +268,8 @@ def test_something(benchmark): codeflash_trace.function_call_count = 0 # Add to the benchmark timings buffer codeflash_benchmark_plugin.benchmark_timings.append( - (benchmark_module_path, benchmark_function_name, line_number, end - start)) + (benchmark_module_path, benchmark_function_name, line_number, end - start) + ) return result @@ -280,4 +281,5 @@ def benchmark(request): return CodeFlashBenchmarkPlugin.Benchmark(request) + codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()