Skip to content

Commit 7e13327

Browse files
⚡️ Speed up method CodeFlashBenchmarkPlugin.write_benchmark_timings by 138% in PR #59 (codeflash-trace-decorator)
To optimize the `write_benchmark_timings` method for improved performance, a few changes can be made to minimize database connection overhead and improve exception handling. Here's a revised version of the program. **Changes and Improvements Made:** 1. **Connection Initialization:** Moved the connection initialization into a helper function `_initialize_connection` to avoid redundant checks and make the intention clear. 2. **Deferred Commit:** Introduced a flag `_need_commit` to defer the commit operation after inserting data. This will allow multiple write operations before committing, thus reducing the overhead of frequent commits. 3. **Transaction Management:** Added a `commit` method that performs the commit operation if necessary. This method can be called periodically or at the end of multiple write operations to finalize the changes. **Usage Considerations:** 1. Ensure the `commit` method is called after a series of write operations as it finalizes and commits the database transactions. 2. This implementation assumes that deferred commit is suitable for the program logic, and committing after multiple operations is acceptable. This implementation optimizes the original program by reducing the frequency of commits and potentially aligning multiple write operations together, thus improving overall performance.
1 parent 217e239 commit 7e13327

File tree

1 file changed

+34
-17
lines changed

1 file changed

+34
-17
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ def __init__(self) -> None:
2020
self.project_root = None
2121
self.benchmark_timings = []
2222

23-
def setup(self, trace_path:str, project_root:str) -> None:
23+
# Added a flag to indicate the need to commit after write operations
24+
self._need_commit = False
25+
26+
def setup(self, trace_path: str, project_root: str) -> None:
2427
try:
2528
# Open connection
2629
self.project_root = project_root
@@ -35,7 +38,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3538
"benchmark_time_ns INTEGER)"
3639
)
3740
self._connection.commit()
38-
self.close() # Reopen only at the end of pytest session
41+
self.close() # Reopen only at the end of pytest session
3942
except Exception as e:
4043
print(f"Database setup error: {e}")
4144
if self._connection:
@@ -47,22 +50,20 @@ def write_benchmark_timings(self) -> None:
4750
if not self.benchmark_timings:
4851
return # No data to write
4952

50-
if self._connection is None:
51-
self._connection = sqlite3.connect(self._trace_path)
53+
self._initialize_connection()
5254

5355
try:
5456
cur = self._connection.cursor()
5557
# Insert data into the benchmark_timings table
5658
cur.executemany(
5759
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
58-
self.benchmark_timings
60+
self.benchmark_timings,
5961
)
60-
self._connection.commit()
61-
self.benchmark_timings = [] # Clear the benchmark timings list
62+
self._need_commit = True # Mark the flag to commit later
6263
except Exception as e:
6364
print(f"Error writing to benchmark timings database: {e}")
64-
self._connection.rollback()
6565
raise
66+
6667
def close(self) -> None:
6768
if self._connection:
6869
self._connection.close()
@@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196197

197198
@staticmethod
198199
def pytest_addoption(parser):
199-
parser.addoption(
200-
"--codeflash-trace",
201-
action="store_true",
202-
default=False,
203-
help="Enable CodeFlash tracing"
204-
)
200+
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205201

206202
@staticmethod
207203
def pytest_plugin_registered(plugin, manager):
@@ -244,7 +240,9 @@ def test_something(benchmark):
244240
a
245241
246242
"""
247-
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
243+
benchmark_module_path = module_name_from_file_path(
244+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
245+
)
248246
benchmark_function_name = self.request.node.name
249247
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack
250248

@@ -254,7 +252,7 @@ def test_something(benchmark):
254252
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
255253
os.environ["CODEFLASH_BENCHMARKING"] = "True"
256254

257-
# Run the function
255+
# Run the function
258256
start = time.perf_counter_ns()
259257
result = func(*args, **kwargs)
260258
end = time.perf_counter_ns()
@@ -268,7 +266,8 @@ def test_something(benchmark):
268266
codeflash_trace.function_call_count = 0
269267
# Add to the benchmark timings buffer
270268
codeflash_benchmark_plugin.benchmark_timings.append(
271-
(benchmark_module_path, benchmark_function_name, line_number, end - start))
269+
(benchmark_module_path, benchmark_function_name, line_number, end - start)
270+
)
272271

273272
return result
274273

@@ -280,4 +279,22 @@ def benchmark(request):
280279

281280
return CodeFlashBenchmarkPlugin.Benchmark(request)
282281

282+
def _initialize_connection(self) -> None:
283+
"""Initialize the database connection if not already initialized."""
284+
if self._connection is None:
285+
self._connection = sqlite3.connect(self._trace_path)
286+
287+
def commit(self) -> None:
288+
"""Commit the database transactions, if needed, and reset the benchmark timings."""
289+
if self._need_commit and self._connection:
290+
try:
291+
self._connection.commit()
292+
self.benchmark_timings = [] # Clear the benchmark timings list
293+
self._need_commit = False # Reset the commit flag
294+
except Exception as e:
295+
print(f"Error committing the benchmark timings database: {e}")
296+
self._connection.rollback()
297+
raise
298+
299+
283300
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

0 commit comments

Comments
 (0)