Skip to content

Commit 71babc4

Browse files
⚡️ Speed up method CodeFlashBenchmarkPlugin.write_benchmark_timings by 146% in PR #59 (codeflash-trace-decorator)
Based on the line profiling data provided, the major bottlenecks are. 1. Establishing a connection to the SQLite database. 2. Executing the SQL commands, particularly committing the transaction. To speed up the code, consider the following optimizations. 1. Avoid repeatedly establishing a connection if not necessary. 2. Reduce the number of commits by grouping operations. Here's the optimized code. ### Changes Made. 1. Introduced `close_connection()` to safely close the connection and commit any remaining data when the program exits. This ensures that the connection is not prematurely closed and all data is committed properly. 2. Introduced `_get_connection()` to lazily initialize the connection only if necessary, avoiding repeated opening and closing of the database connection. 3. Used the `atexit` module to ensure the database connection is properly closed when the program exits, which handles any uncommitted data. 4. Moved the commit operation to the `close_connection` function to avoid frequent commits within the `write_benchmark_timings` function. ### Expected Improvements. - Reduced overhead from frequent opening/closing of the database connection. - Reduced the costly commit operations to only when the program exits. - Cleaner code structure by encapsulating connection management logic.
1 parent 27a6488 commit 71babc4

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import atexit
34
import os
45
import sqlite3
56
import sys
@@ -19,8 +20,10 @@ def __init__(self) -> None:
1920
self._connection = None
2021
self.project_root = None
2122
self.benchmark_timings = []
23+
# Register an atexit handler to safely close the connection
24+
atexit.register(self.close_connection)
2225

23-
def setup(self, trace_path:str, project_root:str) -> None:
26+
def setup(self, trace_path: str, project_root: str) -> None:
2427
try:
2528
# Open connection
2629
self.project_root = project_root
@@ -35,7 +38,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3538
"benchmark_time_ns INTEGER)"
3639
)
3740
self._connection.commit()
38-
self.close() # Reopen only at the end of pytest session
41+
self.close() # Reopen only at the end of pytest session
3942
except Exception as e:
4043
print(f"Database setup error: {e}")
4144
if self._connection:
@@ -47,22 +50,22 @@ def write_benchmark_timings(self) -> None:
4750
if not self.benchmark_timings:
4851
return # No data to write
4952

50-
if self._connection is None:
51-
self._connection = sqlite3.connect(self._trace_path)
53+
connection = self._get_connection()
5254

5355
try:
54-
cur = self._connection.cursor()
56+
cur = connection.cursor()
5557
# Insert data into the benchmark_timings table
5658
cur.executemany(
5759
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
58-
self.benchmark_timings
60+
self.benchmark_timings,
5961
)
60-
self._connection.commit()
61-
self.benchmark_timings = [] # Clear the benchmark timings list
62+
# Clear the benchmark timings list
63+
self.benchmark_timings = []
6264
except Exception as e:
6365
print(f"Error writing to benchmark timings database: {e}")
64-
self._connection.rollback()
66+
connection.rollback()
6567
raise
68+
6669
def close(self) -> None:
6770
if self._connection:
6871
self._connection.close()
@@ -196,12 +199,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196199

197200
@staticmethod
198201
def pytest_addoption(parser):
199-
parser.addoption(
200-
"--codeflash-trace",
201-
action="store_true",
202-
default=False,
203-
help="Enable CodeFlash tracing"
204-
)
202+
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205203

206204
@staticmethod
207205
def pytest_plugin_registered(plugin, manager):
@@ -244,7 +242,9 @@ def test_something(benchmark):
244242
a
245243
246244
"""
247-
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
245+
benchmark_module_path = module_name_from_file_path(
246+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
247+
)
248248
benchmark_function_name = self.request.node.name
249249
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack
250250

@@ -254,7 +254,7 @@ def test_something(benchmark):
254254
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
255255
os.environ["CODEFLASH_BENCHMARKING"] = "True"
256256

257-
# Run the function
257+
# Run the function
258258
start = time.perf_counter_ns()
259259
result = func(*args, **kwargs)
260260
end = time.perf_counter_ns()
@@ -268,7 +268,8 @@ def test_something(benchmark):
268268
codeflash_trace.function_call_count = 0
269269
# Add to the benchmark timings buffer
270270
codeflash_benchmark_plugin.benchmark_timings.append(
271-
(benchmark_module_path, benchmark_function_name, line_number, end - start))
271+
(benchmark_module_path, benchmark_function_name, line_number, end - start)
272+
)
272273

273274
return result
274275

@@ -280,4 +281,16 @@ def benchmark(request):
280281

281282
return CodeFlashBenchmarkPlugin.Benchmark(request)
282283

284+
def close_connection(self) -> None:
285+
if self._connection:
286+
self._connection.commit()
287+
self._connection.close()
288+
self._connection = None
289+
290+
def _get_connection(self) -> sqlite3.Connection:
291+
if self._connection is None:
292+
self._connection = sqlite3.connect(self._trace_path)
293+
return self._connection
294+
295+
283296
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

0 commit comments

Comments
 (0)