Skip to content

Commit 6635143

Browse files
⚡️ Speed up method CodeFlashBenchmarkPlugin.write_benchmark_timings by 212% in PR #59 (codeflash-trace-decorator)
Here is the optimized version of the provided Python program. ### Optimizations made. 1. Setting SQLite PRAGMAs `journal_mode` to `WAL` and `synchronous` to `NORMAL` when the connection is created. This can significantly speed up the write operations by using Write-Ahead Logging and reducing the synchronization overhead. 2. Precompute the SQL INSERT query before the loop to avoid repetitively computing the same string during each execution. 3. Use `benchmark_timings.clear()` method instead of reassigning to an empty list to clear the list. It can provide a slight performance benefit.
1 parent 21a79eb commit 6635143

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self) -> None:
2020
self.project_root = None
2121
self.benchmark_timings = []
2222

23-
def setup(self, trace_path:str, project_root:str) -> None:
23+
def setup(self, trace_path: str, project_root: str) -> None:
2424
try:
2525
# Open connection
2626
self.project_root = project_root
@@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3535
"benchmark_time_ns INTEGER)"
3636
)
3737
self._connection.commit()
38-
self.close() # Reopen only at the end of pytest session
38+
self.close() # Reopen only at the end of pytest session
3939
except Exception as e:
4040
print(f"Database setup error: {e}")
4141
if self._connection:
@@ -49,20 +49,23 @@ def write_benchmark_timings(self) -> None:
4949

5050
if self._connection is None:
5151
self._connection = sqlite3.connect(self._trace_path)
52+
self._connection.execute("PRAGMA journal_mode = WAL")
53+
self._connection.execute("PRAGMA synchronous = NORMAL")
5254

5355
try:
5456
cur = self._connection.cursor()
55-
# Insert data into the benchmark_timings table
56-
cur.executemany(
57-
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
58-
self.benchmark_timings
59-
)
57+
# Prepare SQL statement only once
58+
insert_query = "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)"
59+
60+
# Use `executemany` to insert data into the benchmark_timings table
61+
cur.executemany(insert_query, self.benchmark_timings)
6062
self._connection.commit()
61-
self.benchmark_timings = [] # Clear the benchmark timings list
62-
except Exception as e:
63+
self.benchmark_timings.clear() # Clear the benchmark timings list using clear() for slight efficiency gain
64+
except sqlite3.Error as e:
6365
print(f"Error writing to benchmark timings database: {e}")
6466
self._connection.rollback()
6567
raise
68+
6669
def close(self) -> None:
6770
if self._connection:
6871
self._connection.close()
@@ -196,12 +199,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196199

197200
@staticmethod
198201
def pytest_addoption(parser):
199-
parser.addoption(
200-
"--codeflash-trace",
201-
action="store_true",
202-
default=False,
203-
help="Enable CodeFlash tracing"
204-
)
202+
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205203

206204
@staticmethod
207205
def pytest_plugin_registered(plugin, manager):
@@ -244,7 +242,9 @@ def test_something(benchmark):
244242
a
245243
246244
"""
247-
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
245+
benchmark_module_path = module_name_from_file_path(
246+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
247+
)
248248
benchmark_function_name = self.request.node.name
249249
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack
250250

@@ -254,7 +254,7 @@ def test_something(benchmark):
254254
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
255255
os.environ["CODEFLASH_BENCHMARKING"] = "True"
256256

257-
# Run the function
257+
# Run the function
258258
start = time.perf_counter_ns()
259259
result = func(*args, **kwargs)
260260
end = time.perf_counter_ns()
@@ -268,7 +268,8 @@ def test_something(benchmark):
268268
codeflash_trace.function_call_count = 0
269269
# Add to the benchmark timings buffer
270270
codeflash_benchmark_plugin.benchmark_timings.append(
271-
(benchmark_module_path, benchmark_function_name, line_number, end - start))
271+
(benchmark_module_path, benchmark_function_name, line_number, end - start)
272+
)
272273

273274
return result
274275

@@ -280,4 +281,5 @@ def benchmark(request):
280281

281282
return CodeFlashBenchmarkPlugin.Benchmark(request)
282283

284+
283285
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

0 commit comments

Comments
 (0)