1616if TYPE_CHECKING :
1717 from codeflash .models .models import BenchmarkKey
1818
19- IS_PYTEST_BENCHMARK_INSTALLED = importlib .util .find_spec ("pytest_benchmark" ) is not None
19+ try :
20+ import pytest_benchmark
21+
22+ PYTEST_BENCHMARK_INSTALLED = True
23+ except ImportError :
24+ PYTEST_BENCHMARK_INSTALLED = False
2025
2126
2227class CodeFlashBenchmarkPlugin :
@@ -288,7 +293,7 @@ def pytest_configure(config: pytest.Config) -> None:
288293 """Register the benchmark marker and disable conflicting plugins."""
289294 config .addinivalue_line ("markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing" )
290295
291- if config .getoption ("--codeflash-trace" ) and IS_PYTEST_BENCHMARK_INSTALLED :
296+ if config .getoption ("--codeflash-trace" ) and PYTEST_BENCHMARK_INSTALLED :
292297 config .option .benchmark_disable = True
293298 config .pluginmanager .set_blocked ("pytest_benchmark" )
294299 config .pluginmanager .set_blocked ("pytest-benchmark" )
@@ -302,6 +307,35 @@ def pytest_addoption(parser: pytest.Parser) -> None:
302307
303308@pytest .fixture
304309def benchmark (request : pytest .FixtureRequest ) -> object :
305- if not request .config .getoption ("--codeflash-trace" ):
310+ """Benchmark fixture that works with or without pytest-benchmark installed."""
311+ config = request .config
312+
313+ # If --codeflash-trace is enabled, use our implementation
314+ if config .getoption ("--codeflash-trace" ):
315+ return codeflash_benchmark_plugin .Benchmark (request )
316+
317+ # If pytest-benchmark is installed and --codeflash-trace is not enabled,
318+ # return the normal pytest-benchmark fixture
319+ if PYTEST_BENCHMARK_INSTALLED :
320+ from pytest_benchmark .fixture import BenchmarkFixture as BSF # noqa: N814
321+
322+ bs = getattr (config , "_benchmarksession" , None )
323+ if bs and bs .skip :
324+ pytest .skip ("Benchmarks are skipped (--benchmark-skip was used)." )
325+
326+ node = request .node
327+ marker = node .get_closest_marker ("benchmark" )
328+ options = dict (marker .kwargs ) if marker else {}
329+
330+ if bs :
331+ return BSF (
332+ node ,
333+ add_stats = bs .benchmarks .append ,
334+ logger = bs .logger ,
335+ warner = request .node .warn ,
336+ disabled = bs .disabled ,
337+ ** dict (bs .options , ** options ),
338+ )
306339 return lambda func , * args , ** kwargs : func (* args , ** kwargs )
307- return codeflash_benchmark_plugin .Benchmark (request )
340+
341+ return lambda func , * args , ** kwargs : func (* args , ** kwargs )
0 commit comments