diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index 09858601c..b969959e7 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -1,10 +1,13 @@ from __future__ import annotations + import os import sqlite3 import sys import time from pathlib import Path + import pytest + from codeflash.benchmarking.codeflash_trace import codeflash_trace from codeflash.models.models import BenchmarkKey @@ -15,7 +18,7 @@ def __init__(self) -> None: self._connection = None self.benchmark_timings = [] - def setup(self, trace_path:str) -> None: + def setup(self, trace_path: str) -> None: try: # Open connection self._trace_path = trace_path @@ -28,7 +31,7 @@ def setup(self, trace_path:str) -> None: "benchmark_time_ns INTEGER)" ) self._connection.commit() - self.close() # Reopen only at the end of pytest session + self.close() # Reopen only at the end of pytest session except Exception as e: print(f"Database setup error: {e}") if self._connection: @@ -42,20 +45,23 @@ def write_benchmark_timings(self) -> None: if self._connection is None: self._connection = sqlite3.connect(self._trace_path) + self._connection.execute("PRAGMA synchronous = OFF") + self._connection.execute("PRAGMA journal_mode = MEMORY") try: cur = self._connection.cursor() # Insert data into the benchmark_timings table cur.executemany( "INSERT INTO benchmark_timings (benchmark_file_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)", - self.benchmark_timings + self.benchmark_timings, ) self._connection.commit() - self.benchmark_timings = [] # Clear the benchmark timings list + self.benchmark_timings = [] # Clear the benchmark timings list except Exception as e: print(f"Error writing to benchmark timings database: {e}") self._connection.rollback() raise + def close(self) -> None: if self._connection: self._connection.close() @@ -189,12 +195,7 @@ def pytest_sessionfinish(self, session, exitstatus): @staticmethod def pytest_addoption(parser): - parser.addoption( - "--codeflash-trace", - action="store_true", - default=False, - help="Enable CodeFlash tracing" - ) + parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing") @staticmethod def pytest_plugin_registered(plugin, manager): @@ -246,7 +247,7 @@ def test_something(benchmark): os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) os.environ["CODEFLASH_BENCHMARKING"] = "True" - # Run the function + # Run the function start = time.perf_counter_ns() result = func(*args, **kwargs) end = time.perf_counter_ns() @@ -260,7 +261,8 @@ def test_something(benchmark): codeflash_trace.function_call_count = 0 # Add to the benchmark timings buffer codeflash_benchmark_plugin.benchmark_timings.append( - (benchmark_file_path, benchmark_function_name, line_number, end - start)) + (benchmark_file_path, benchmark_function_name, line_number, end - start) + ) return result @@ -272,4 +274,5 @@ def benchmark(request): return CodeFlashBenchmarkPlugin.Benchmark(request) -codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin() \ No newline at end of file + +codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()