Skip to content

Commit 32d9919

Browse files
committed
add the plugin as a pytest entry point
1 parent bb888a3 commit 32d9919

File tree

4 files changed

+41
-40
lines changed

4 files changed

+41
-40
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import importlib.util
34
import os
45
import sqlite3
56
import sys
@@ -12,6 +13,8 @@
1213
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1314
from codeflash.code_utils.code_utils import module_name_from_file_path
1415

16+
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
17+
1518

1619
@dataclass(frozen=True)
1720
class BenchmarkKey:
@@ -212,19 +215,18 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
212215
def pytest_addoption(parser: pytest.Parser) -> None:
213216
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
214217

215-
@staticmethod
216-
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
217-
# Not necessary since run with -p no:benchmark, but just in case
218-
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
219-
manager.unregister(plugin)
220-
221218
@staticmethod
222219
def pytest_configure(config: pytest.Config) -> None:
223-
"""Register the benchmark marker."""
220+
"""Register the benchmark marker and disable conflicting plugins."""
224221
config.addinivalue_line(
225222
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
226223
)
227224

225+
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
226+
object.__setattr__(config.option, "benchmark_disable", True)
227+
config.pluginmanager.set_blocked("pytest_benchmark")
228+
config.pluginmanager.set_blocked("pytest-benchmark")
229+
228230
@staticmethod
229231
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
230232
# Skip tests that don't have the benchmark marker
@@ -297,3 +299,20 @@ def benchmark(request: pytest.FixtureRequest) -> object:
297299

298300

299301
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
302+
303+
304+
def pytest_configure(config: pytest.Config) -> None:
305+
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
306+
307+
308+
def pytest_addoption(parser: pytest.Parser) -> None:
309+
parser.addoption(
310+
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
311+
)
312+
313+
314+
@pytest.fixture
315+
def benchmark(request: pytest.FixtureRequest) -> object:
316+
if not request.config.getoption("--codeflash-trace"):
317+
return lambda func, *args, **kwargs: func(*args, **kwargs)
318+
return codeflash_benchmark_plugin.Benchmark(request)

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ def run_benchmarks(
9090
logger.info(
9191
f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
9292
)
93-
raise SystemExit # noqa: TRY301
94-
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
95-
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
96-
function_to_results = validate_and_format_benchmark_table(
97-
function_benchmark_timings, total_benchmark_timings
98-
)
99-
print_benchmark_table(function_to_results)
93+
else:
94+
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
95+
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
96+
function_to_results = validate_and_format_benchmark_table(
97+
function_benchmark_timings, total_benchmark_timings
98+
)
99+
print_benchmark_table(function_to_results)
100100
except Exception as e:
101101
logger.info(f"Error while tracing existing benchmarks: {e}")
102102
logger.info("Information on existing benchmarks will not be available for this run.")

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ dev = [
6767
"types-openpyxl>=3.1.5.20241020",
6868
"types-regex>=2024.9.11.20240912",
6969
"types-python-dateutil>=2.9.0.20241003",
70-
"pytest-benchmark>=5.1.0",
7170
"types-gevent>=24.11.0.20241230,<25",
7271
"types-greenlet>=3.1.0.20241221,<4",
7372
"types-pexpect>=4.9.0.20241208,<5",
@@ -300,3 +299,6 @@ markers = [
300299
[build-system]
301300
requires = ["hatchling", "uv-dynamic-versioning"]
302301
build-backend = "hatchling.build"
302+
303+
[project.entry-points.pytest11]
304+
codeflash = "codeflash.benchmarking.plugin.plugin"

uv.lock

Lines changed: 5 additions & 25 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)