Skip to content

Commit ad90c2d

Browse files
committed
Benchmark Fixture fixes
Update plugin.py
1 parent 47e29ec commit ad90c2d

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
if TYPE_CHECKING:
1717
from codeflash.models.models import BenchmarkKey
1818

19-
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
19+
try:
20+
import pytest_benchmark
21+
22+
PYTEST_BENCHMARK_INSTALLED = True
23+
except ImportError:
24+
PYTEST_BENCHMARK_INSTALLED = False
2025

2126

2227
class CodeFlashBenchmarkPlugin:
@@ -288,7 +293,7 @@ def pytest_configure(config: pytest.Config) -> None:
288293
"""Register the benchmark marker and disable conflicting plugins."""
289294
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
290295

291-
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
296+
if config.getoption("--codeflash-trace") and PYTEST_BENCHMARK_INSTALLED:
292297
config.option.benchmark_disable = True
293298
config.pluginmanager.set_blocked("pytest_benchmark")
294299
config.pluginmanager.set_blocked("pytest-benchmark")
@@ -302,6 +307,35 @@ def pytest_addoption(parser: pytest.Parser) -> None:
302307

303308
@pytest.fixture
304309
def benchmark(request: pytest.FixtureRequest) -> object:
305-
if not request.config.getoption("--codeflash-trace"):
310+
"""Benchmark fixture that works with or without pytest-benchmark installed."""
311+
config = request.config
312+
313+
# If --codeflash-trace is enabled, use our implementation
314+
if config.getoption("--codeflash-trace"):
315+
return codeflash_benchmark_plugin.Benchmark(request)
316+
317+
# If pytest-benchmark is installed and --codeflash-trace is not enabled,
318+
# return the normal pytest-benchmark fixture
319+
if PYTEST_BENCHMARK_INSTALLED:
320+
from pytest_benchmark.fixture import BenchmarkFixture as BSF # noqa: N814
321+
322+
bs = getattr(config, "_benchmarksession", None)
323+
if bs and bs.skip:
324+
pytest.skip("Benchmarks are skipped (--benchmark-skip was used).")
325+
326+
node = request.node
327+
marker = node.get_closest_marker("benchmark")
328+
options = dict(marker.kwargs) if marker else {}
329+
330+
if bs:
331+
return BSF(
332+
node,
333+
add_stats=bs.benchmarks.append,
334+
logger=bs.logger,
335+
warner=request.node.warn,
336+
disabled=bs.disabled,
337+
**dict(bs.options, **options),
338+
)
306339
return lambda func, *args, **kwargs: func(*args, **kwargs)
307-
return codeflash_benchmark_plugin.Benchmark(request)
340+
341+
return lambda func, *args, **kwargs: func(*args, **kwargs)

0 commit comments

Comments
 (0)