11from __future__ import annotations
22
3+ import importlib .util
34import os
45import sqlite3
56import sys
67import time
78from pathlib import Path
9+ from typing import TYPE_CHECKING
810
911import pytest
1012
1113from codeflash .benchmarking .codeflash_trace import codeflash_trace
1214from codeflash .code_utils .code_utils import module_name_from_file_path
13- from codeflash .models .models import BenchmarkKey
15+
16+ if TYPE_CHECKING :
17+ from codeflash .models .models import BenchmarkKey
18+
19+ IS_PYTEST_BENCHMARK_INSTALLED = importlib .util .find_spec ("pytest_benchmark" ) is not None
1420
1521
1622class CodeFlashBenchmarkPlugin :
@@ -71,6 +77,8 @@ def close(self) -> None:
7177
7278 @staticmethod
7379 def get_function_benchmark_timings (trace_path : Path ) -> dict [str , dict [BenchmarkKey , int ]]:
80+ from codeflash .models .models import BenchmarkKey
81+
7482 """Process the trace file and extract timing data for all functions.
7583
7684 Args:
@@ -131,6 +139,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
131139
132140 @staticmethod
133141 def get_benchmark_timings (trace_path : Path ) -> dict [BenchmarkKey , int ]:
142+ from codeflash .models .models import BenchmarkKey
143+
134144 """Extract total benchmark timings from trace files.
135145
136146 Args:
@@ -199,23 +209,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
199209 # Close the database connection
200210 self .close ()
201211
202- @staticmethod
203- def pytest_addoption (parser : pytest .Parser ) -> None :
204- parser .addoption ("--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing" )
205-
206- @staticmethod
207- def pytest_plugin_registered (plugin , manager ) -> None : # noqa: ANN001
208- # Not necessary since run with -p no:benchmark, but just in case
209- if hasattr (plugin , "name" ) and plugin .name == "pytest-benchmark" :
210- manager .unregister (plugin )
211-
212- @staticmethod
213- def pytest_configure (config : pytest .Config ) -> None :
214- """Register the benchmark marker."""
215- config .addinivalue_line (
216- "markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing"
217- )
218-
219212 @staticmethod
220213 def pytest_collection_modifyitems (config : pytest .Config , items : list [pytest .Item ]) -> None :
221214 # Skip tests that don't have the benchmark fixture
@@ -259,8 +252,9 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
259252 def _run_benchmark (self , func , * args , ** kwargs ): # noqa: ANN001, ANN002, ANN003, ANN202
260253 """Actual benchmark implementation."""
261254 benchmark_module_path = module_name_from_file_path (
262- Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root )
255+ Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root ), traverse_up = True
263256 )
257+
264258 benchmark_function_name = self .request .node .name
265259 line_number = int (str (sys ._getframe (2 ).f_lineno )) # 2 frames up in the call stack # noqa: SLF001
266260 # Set env vars
@@ -286,13 +280,28 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
286280
287281 return result
288282
289- @staticmethod
290- @pytest .fixture
291- def benchmark (request : pytest .FixtureRequest ) -> object :
292- if not request .config .getoption ("--codeflash-trace" ):
293- return None
294283
295- return CodeFlashBenchmarkPlugin . Benchmark ( request )
284+ codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ( )
296285
297286
298- codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ()
287+ def pytest_configure (config : pytest .Config ) -> None :
288+ """Register the benchmark marker and disable conflicting plugins."""
289+ config .addinivalue_line ("markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing" )
290+
291+ if config .getoption ("--codeflash-trace" ) and IS_PYTEST_BENCHMARK_INSTALLED :
292+ config .option .benchmark_disable = True
293+ config .pluginmanager .set_blocked ("pytest_benchmark" )
294+ config .pluginmanager .set_blocked ("pytest-benchmark" )
295+
296+
297+ def pytest_addoption (parser : pytest .Parser ) -> None :
298+ parser .addoption (
299+ "--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing for benchmarks"
300+ )
301+
302+
303+ @pytest .fixture
304+ def benchmark (request : pytest .FixtureRequest ) -> object :
305+ if not request .config .getoption ("--codeflash-trace" ):
306+ return lambda func , * args , ** kwargs : func (* args , ** kwargs )
307+ return codeflash_benchmark_plugin .Benchmark (request )
0 commit comments