Skip to content

Commit df1d598

Browse files
⚡️ Speed up method CodeFlashBenchmarkPlugin.pytest_collection_modifyitems by 161% in PR #59 (codeflash-trace-decorator)
## Explanation of Changes. 1. **Optimize the Loop in `pytest_collection_modifyitems` Method**. - Instead of checking the fixture names and conditionally adding the marker within the same loop, we separate the items into two lists: one for items with the `benchmark` fixture and one for items without it. - This optimization helps to reduce the possible overhead of repeatedly calling `add_marker` by first classifying the items and then applying the marker only to those necessary. - Finally, we concatenate the lists to retain the original order, except with tests without benchmark fixtures getting the skip marker. Note: The item list ordering at the end of the method may not be necessary depending on the context, but this ensures that we process markers efficiently without changing the original relative order of test items.
1 parent 21a79eb commit df1d598

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self) -> None:
2020
self.project_root = None
2121
self.benchmark_timings = []
2222

23-
def setup(self, trace_path:str, project_root:str) -> None:
23+
def setup(self, trace_path: str, project_root: str) -> None:
2424
try:
2525
# Open connection
2626
self.project_root = project_root
@@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3535
"benchmark_time_ns INTEGER)"
3636
)
3737
self._connection.commit()
38-
self.close() # Reopen only at the end of pytest session
38+
self.close() # Reopen only at the end of pytest session
3939
except Exception as e:
4040
print(f"Database setup error: {e}")
4141
if self._connection:
@@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None:
5555
# Insert data into the benchmark_timings table
5656
cur.executemany(
5757
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
58-
self.benchmark_timings
58+
self.benchmark_timings,
5959
)
6060
self._connection.commit()
61-
self.benchmark_timings = [] # Clear the benchmark timings list
61+
self.benchmark_timings = [] # Clear the benchmark timings list
6262
except Exception as e:
6363
print(f"Error writing to benchmark timings database: {e}")
6464
self._connection.rollback()
6565
raise
66+
6667
def close(self) -> None:
6768
if self._connection:
6869
self._connection.close()
@@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196197

197198
@staticmethod
198199
def pytest_addoption(parser):
199-
parser.addoption(
200-
"--codeflash-trace",
201-
action="store_true",
202-
default=False,
203-
help="Enable CodeFlash tracing"
204-
)
200+
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205201

206202
@staticmethod
207203
def pytest_plugin_registered(plugin, manager):
@@ -216,11 +212,21 @@ def pytest_collection_modifyitems(config, items):
216212
return
217213

218214
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
215+
216+
items_with_benchmark = []
217+
items_without_benchmark = []
218+
219219
for item in items:
220220
if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
221-
continue
221+
items_with_benchmark.append(item)
222+
else:
223+
items_without_benchmark.append(item)
224+
225+
for item in items_without_benchmark:
222226
item.add_marker(skip_no_benchmark)
223227

228+
items[:] = items_with_benchmark + items_without_benchmark
229+
224230
# Benchmark fixture
225231
class Benchmark:
226232
def __init__(self, request):
@@ -244,7 +250,9 @@ def test_something(benchmark):
244250
a
245251
246252
"""
247-
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
253+
benchmark_module_path = module_name_from_file_path(
254+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
255+
)
248256
benchmark_function_name = self.request.node.name
249257
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack
250258

@@ -254,7 +262,7 @@ def test_something(benchmark):
254262
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
255263
os.environ["CODEFLASH_BENCHMARKING"] = "True"
256264

257-
# Run the function
265+
# Run the function
258266
start = time.perf_counter_ns()
259267
result = func(*args, **kwargs)
260268
end = time.perf_counter_ns()
@@ -268,7 +276,8 @@ def test_something(benchmark):
268276
codeflash_trace.function_call_count = 0
269277
# Add to the benchmark timings buffer
270278
codeflash_benchmark_plugin.benchmark_timings.append(
271-
(benchmark_module_path, benchmark_function_name, line_number, end - start))
279+
(benchmark_module_path, benchmark_function_name, line_number, end - start)
280+
)
272281

273282
return result
274283

@@ -280,4 +289,5 @@ def benchmark(request):
280289

281290
return CodeFlashBenchmarkPlugin.Benchmark(request)
282291

292+
283293
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

0 commit comments

Comments
 (0)