11from __future__ import annotations
22
3+ import importlib .util
34import os
45import sqlite3
56import sys
67import time
78from pathlib import Path
9+ from typing import TYPE_CHECKING
810
911import pytest
1012
1113from codeflash .benchmarking .codeflash_trace import codeflash_trace
1214from codeflash .code_utils .code_utils import module_name_from_file_path
13- from codeflash .models .models import BenchmarkKey
15+
16+ if TYPE_CHECKING :
17+ from codeflash .models .models import BenchmarkKey
18+
19+ IS_PYTEST_BENCHMARK_INSTALLED = importlib .util .find_spec ("pytest_benchmark" ) is not None
1420
1521
1622# hello
@@ -72,6 +78,8 @@ def close(self) -> None:
7278
7379 @staticmethod
7480 def get_function_benchmark_timings (trace_path : Path ) -> dict [str , dict [BenchmarkKey , int ]]:
81+ from codeflash .models .models import BenchmarkKey
82+
7583 """Process the trace file and extract timing data for all functions.
7684
7785 Args:
@@ -132,6 +140,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
132140
133141 @staticmethod
134142 def get_benchmark_timings (trace_path : Path ) -> dict [BenchmarkKey , int ]:
143+ from codeflash .models .models import BenchmarkKey
144+
135145 """Extract total benchmark timings from trace files.
136146
137147 Args:
@@ -200,23 +210,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
200210 # Close the database connection
201211 self .close ()
202212
203- @staticmethod
204- def pytest_addoption (parser : pytest .Parser ) -> None :
205- parser .addoption ("--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing" )
206-
207- @staticmethod
208- def pytest_plugin_registered (plugin , manager ) -> None : # noqa: ANN001
209- # Not necessary since run with -p no:benchmark, but just in case
210- if hasattr (plugin , "name" ) and plugin .name == "pytest-benchmark" :
211- manager .unregister (plugin )
212-
213- @staticmethod
214- def pytest_configure (config : pytest .Config ) -> None :
215- """Register the benchmark marker."""
216- config .addinivalue_line (
217- "markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing"
218- )
219-
220213 @staticmethod
221214 def pytest_collection_modifyitems (config : pytest .Config , items : list [pytest .Item ]) -> None :
222215 # Skip tests that don't have the benchmark fixture
@@ -260,8 +253,9 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
260253 def _run_benchmark (self , func , * args , ** kwargs ): # noqa: ANN001, ANN002, ANN003, ANN202
261254 """Actual benchmark implementation."""
262255 benchmark_module_path = module_name_from_file_path (
263- Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root )
256+ Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root ), traverse_up = True
264257 )
258+
265259 benchmark_function_name = self .request .node .name
266260 line_number = int (str (sys ._getframe (2 ).f_lineno )) # 2 frames up in the call stack # noqa: SLF001
267261 # Set env vars
@@ -287,13 +281,28 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
287281
288282 return result
289283
290- @staticmethod
291- @pytest .fixture
292- def benchmark (request : pytest .FixtureRequest ) -> object :
293- if not request .config .getoption ("--codeflash-trace" ):
294- return None
295284
296- return CodeFlashBenchmarkPlugin . Benchmark ( request )
285+ codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ( )
297286
298287
299- codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ()
288+ def pytest_configure (config : pytest .Config ) -> None :
289+ """Register the benchmark marker and disable conflicting plugins."""
290+ config .addinivalue_line ("markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing" )
291+
292+ if config .getoption ("--codeflash-trace" ) and IS_PYTEST_BENCHMARK_INSTALLED :
293+ config .option .benchmark_disable = True
294+ config .pluginmanager .set_blocked ("pytest_benchmark" )
295+ config .pluginmanager .set_blocked ("pytest-benchmark" )
296+
297+
298+ def pytest_addoption (parser : pytest .Parser ) -> None :
299+ parser .addoption (
300+ "--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing for benchmarks"
301+ )
302+
303+
304+ @pytest .fixture
305+ def benchmark (request : pytest .FixtureRequest ) -> object :
306+ if not request .config .getoption ("--codeflash-trace" ):
307+ return lambda func , * args , ** kwargs : func (* args , ** kwargs )
308+ return codeflash_benchmark_plugin .Benchmark (request )
0 commit comments