11from __future__ import annotations
22
3- import importlib .util
43import os
54import sqlite3
65import sys
76import time
87from pathlib import Path
9- from typing import TYPE_CHECKING
108
119import pytest
1210
1311from codeflash .benchmarking .codeflash_trace import codeflash_trace
1412from codeflash .code_utils .code_utils import module_name_from_file_path
15-
16- if TYPE_CHECKING :
17- from codeflash .models .models import BenchmarkKey
18-
19- IS_PYTEST_BENCHMARK_INSTALLED = importlib .util .find_spec ("pytest_benchmark" ) is not None
13+ from codeflash .models .models import BenchmarkKey
2014
2115
2216class CodeFlashBenchmarkPlugin :
@@ -77,8 +71,6 @@ def close(self) -> None:
7771
7872 @staticmethod
7973 def get_function_benchmark_timings (trace_path : Path ) -> dict [str , dict [BenchmarkKey , int ]]:
80- from codeflash .models .models import BenchmarkKey
81-
8274 """Process the trace file and extract timing data for all functions.
8375
8476 Args:
@@ -139,8 +131,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
139131
140132 @staticmethod
141133 def get_benchmark_timings (trace_path : Path ) -> dict [BenchmarkKey , int ]:
142- from codeflash .models .models import BenchmarkKey
143-
144134 """Extract total benchmark timings from trace files.
145135
146136 Args:
@@ -209,6 +199,23 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
209199 # Close the database connection
210200 self .close ()
211201
202+ @staticmethod
203+ def pytest_addoption (parser : pytest .Parser ) -> None :
204+ parser .addoption ("--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing" )
205+
206+ @staticmethod
207+ def pytest_plugin_registered (plugin , manager ) -> None : # noqa: ANN001
208+ # Not necessary since run with -p no:benchmark, but just in case
209+ if hasattr (plugin , "name" ) and plugin .name == "pytest-benchmark" :
210+ manager .unregister (plugin )
211+
212+ @staticmethod
213+ def pytest_configure (config : pytest .Config ) -> None :
214+ """Register the benchmark marker."""
215+ config .addinivalue_line (
216+ "markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing"
217+ )
218+
212219 @staticmethod
213220 def pytest_collection_modifyitems (config : pytest .Config , items : list [pytest .Item ]) -> None :
214221 # Skip tests that don't have the benchmark fixture
@@ -252,9 +259,8 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
252259 def _run_benchmark (self , func , * args , ** kwargs ): # noqa: ANN001, ANN002, ANN003, ANN202
253260 """Actual benchmark implementation."""
254261 benchmark_module_path = module_name_from_file_path (
255- Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root ), traverse_up = True
262+ Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root )
256263 )
257-
258264 benchmark_function_name = self .request .node .name
259265 line_number = int (str (sys ._getframe (2 ).f_lineno )) # 2 frames up in the call stack # noqa: SLF001
260266 # Set env vars
@@ -280,28 +286,13 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
280286
281287 return result
282288
289+ @staticmethod
290+ @pytest .fixture
291+ def benchmark (request : pytest .FixtureRequest ) -> object :
292+ if not request .config .getoption ("--codeflash-trace" ):
293+ return None
283294
284- codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ()
285-
286-
287- def pytest_configure (config : pytest .Config ) -> None :
288- """Register the benchmark marker and disable conflicting plugins."""
289- config .addinivalue_line ("markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing" )
290-
291- if config .getoption ("--codeflash-trace" ) and IS_PYTEST_BENCHMARK_INSTALLED :
292- config .option .benchmark_disable = True
293- config .pluginmanager .set_blocked ("pytest_benchmark" )
294- config .pluginmanager .set_blocked ("pytest-benchmark" )
295-
296-
297- def pytest_addoption (parser : pytest .Parser ) -> None :
298- parser .addoption (
299- "--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing for benchmarks"
300- )
295+ return CodeFlashBenchmarkPlugin .Benchmark (request )
301296
302297
303- @pytest .fixture
304- def benchmark (request : pytest .FixtureRequest ) -> object :
305- if not request .config .getoption ("--codeflash-trace" ):
306- return lambda func , * args , ** kwargs : func (* args , ** kwargs )
307- return codeflash_benchmark_plugin .Benchmark (request )
298+ codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ()
0 commit comments