|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import importlib.util |
3 | 4 | import os |
4 | 5 | import sqlite3 |
5 | 6 | import sys |
|
12 | 13 | from codeflash.benchmarking.codeflash_trace import codeflash_trace |
13 | 14 | from codeflash.code_utils.code_utils import module_name_from_file_path |
14 | 15 |
|
| 16 | +IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None |
| 17 | + |
15 | 18 |
|
16 | 19 | @dataclass(frozen=True) |
17 | 20 | class BenchmarkKey: |
@@ -212,19 +215,18 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR |
212 | 215 | def pytest_addoption(parser: pytest.Parser) -> None: |
213 | 216 | parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing") |
214 | 217 |
|
215 | | - @staticmethod |
216 | | - def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001 |
217 | | - # Not necessary since run with -p no:benchmark, but just in case |
218 | | - if hasattr(plugin, "name") and plugin.name == "pytest-benchmark": |
219 | | - manager.unregister(plugin) |
220 | | - |
221 | 218 | @staticmethod |
222 | 219 | def pytest_configure(config: pytest.Config) -> None: |
223 | | - """Register the benchmark marker.""" |
| 220 | + """Register the benchmark marker and disable conflicting plugins.""" |
224 | 221 | config.addinivalue_line( |
225 | 222 | "markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing" |
226 | 223 | ) |
227 | 224 |
|
| 225 | + if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED: |
| 226 | + object.__setattr__(config.option, "benchmark_disable", True) |
| 227 | + config.pluginmanager.set_blocked("pytest_benchmark") |
| 228 | + config.pluginmanager.set_blocked("pytest-benchmark") |
| 229 | + |
228 | 230 | @staticmethod |
229 | 231 | def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: |
230 | 232 | # Skip tests that don't have the benchmark marker |
@@ -297,3 +299,20 @@ def benchmark(request: pytest.FixtureRequest) -> object: |
297 | 299 |
|
298 | 300 |
|
299 | 301 | codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin() |
| 302 | + |
| 303 | + |
| 304 | +def pytest_configure(config: pytest.Config) -> None: |
| 305 | + config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing") |
| 306 | + |
| 307 | + |
| 308 | +def pytest_addoption(parser: pytest.Parser) -> None: |
| 309 | + parser.addoption( |
| 310 | + "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks" |
| 311 | + ) |
| 312 | + |
| 313 | + |
| 314 | +@pytest.fixture |
| 315 | +def benchmark(request: pytest.FixtureRequest) -> object: |
| 316 | + if not request.config.getoption("--codeflash-trace"): |
| 317 | + return lambda func, *args, **kwargs: func(*args, **kwargs) |
| 318 | + return codeflash_benchmark_plugin.Benchmark(request) |
0 commit comments