Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions codeflash/benchmarking/plugin/plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import atexit
import os
import sqlite3
import sys
Expand All @@ -19,8 +20,10 @@ def __init__(self) -> None:
self._connection = None
self.project_root = None
self.benchmark_timings = []
# Register an atexit handler to safely close the connection
atexit.register(self.close_connection)

def setup(self, trace_path:str, project_root:str) -> None:
def setup(self, trace_path: str, project_root: str) -> None:
try:
# Open connection
self.project_root = project_root
Expand All @@ -35,7 +38,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
"benchmark_time_ns INTEGER)"
)
self._connection.commit()
self.close() # Reopen only at the end of pytest session
self.close() # Reopen only at the end of pytest session
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
Expand All @@ -47,22 +50,22 @@ def write_benchmark_timings(self) -> None:
if not self.benchmark_timings:
return # No data to write

if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
connection = self._get_connection()

try:
cur = self._connection.cursor()
cur = connection.cursor()
# Insert data into the benchmark_timings table
cur.executemany(
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
self.benchmark_timings
self.benchmark_timings,
)
self._connection.commit()
self.benchmark_timings = [] # Clear the benchmark timings list
# Clear the benchmark timings list
self.benchmark_timings = []
except Exception as e:
print(f"Error writing to benchmark timings database: {e}")
self._connection.rollback()
connection.rollback()
raise

def close(self) -> None:
if self._connection:
self._connection.close()
Expand Down Expand Up @@ -196,12 +199,7 @@ def pytest_sessionfinish(self, session, exitstatus):

@staticmethod
def pytest_addoption(parser):
parser.addoption(
"--codeflash-trace",
action="store_true",
default=False,
help="Enable CodeFlash tracing"
)
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")

@staticmethod
def pytest_plugin_registered(plugin, manager):
Expand Down Expand Up @@ -244,7 +242,9 @@ def test_something(benchmark):
a

"""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
benchmark_module_path = module_name_from_file_path(
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
)
benchmark_function_name = self.request.node.name
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack

Expand All @@ -254,7 +254,7 @@ def test_something(benchmark):
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
os.environ["CODEFLASH_BENCHMARKING"] = "True"

# Run the function
# Run the function
start = time.perf_counter_ns()
result = func(*args, **kwargs)
end = time.perf_counter_ns()
Expand All @@ -268,7 +268,8 @@ def test_something(benchmark):
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(
(benchmark_module_path, benchmark_function_name, line_number, end - start))
(benchmark_module_path, benchmark_function_name, line_number, end - start)
)

return result

Expand All @@ -280,4 +281,16 @@ def benchmark(request):

return CodeFlashBenchmarkPlugin.Benchmark(request)

def close_connection(self) -> None:
if self._connection:
self._connection.commit()
self._connection.close()
self._connection = None

def _get_connection(self) -> sqlite3.Connection:
if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
return self._connection


codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
Loading