Skip to content

Commit 3690f07

Browse files
committed
Update plugin.py
1 parent efa9c68 commit 3690f07

File tree

1 file changed

+26
-35
lines changed

1 file changed

+26
-35
lines changed

codeflash/benchmarking/plugin/plugin.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
from __future__ import annotations
22

3-
import importlib.util
43
import os
54
import sqlite3
65
import sys
76
import time
87
from pathlib import Path
9-
from typing import TYPE_CHECKING
108

119
import pytest
1210

1311
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1412
from codeflash.code_utils.code_utils import module_name_from_file_path
15-
16-
if TYPE_CHECKING:
17-
from codeflash.models.models import BenchmarkKey
18-
19-
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
13+
from codeflash.models.models import BenchmarkKey
2014

2115

2216
class CodeFlashBenchmarkPlugin:
@@ -77,8 +71,6 @@ def close(self) -> None:
7771

7872
@staticmethod
7973
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
80-
from codeflash.models.models import BenchmarkKey
81-
8274
"""Process the trace file and extract timing data for all functions.
8375
8476
Args:
@@ -139,8 +131,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
139131

140132
@staticmethod
141133
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
142-
from codeflash.models.models import BenchmarkKey
143-
144134
"""Extract total benchmark timings from trace files.
145135
146136
Args:
@@ -209,6 +199,23 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
209199
# Close the database connection
210200
self.close()
211201

202+
@staticmethod
203+
def pytest_addoption(parser: pytest.Parser) -> None:
204+
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205+
206+
@staticmethod
207+
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
208+
# Not necessary since run with -p no:benchmark, but just in case
209+
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
210+
manager.unregister(plugin)
211+
212+
@staticmethod
213+
def pytest_configure(config: pytest.Config) -> None:
214+
"""Register the benchmark marker."""
215+
config.addinivalue_line(
216+
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
217+
)
218+
212219
@staticmethod
213220
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
214221
# Skip tests that don't have the benchmark fixture
@@ -252,9 +259,8 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
252259
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
253260
"""Actual benchmark implementation."""
254261
benchmark_module_path = module_name_from_file_path(
255-
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
262+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
256263
)
257-
258264
benchmark_function_name = self.request.node.name
259265
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
260266
# Set env vars
@@ -280,28 +286,13 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
280286

281287
return result
282288

289+
@staticmethod
290+
@pytest.fixture
291+
def benchmark(request: pytest.FixtureRequest) -> object:
292+
if not request.config.getoption("--codeflash-trace"):
293+
return None
283294

284-
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
285-
286-
287-
def pytest_configure(config: pytest.Config) -> None:
288-
"""Register the benchmark marker and disable conflicting plugins."""
289-
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
290-
291-
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
292-
config.option.benchmark_disable = True
293-
config.pluginmanager.set_blocked("pytest_benchmark")
294-
config.pluginmanager.set_blocked("pytest-benchmark")
295-
296-
297-
def pytest_addoption(parser: pytest.Parser) -> None:
298-
parser.addoption(
299-
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
300-
)
295+
return CodeFlashBenchmarkPlugin.Benchmark(request)
301296

302297

303-
@pytest.fixture
304-
def benchmark(request: pytest.FixtureRequest) -> object:
305-
if not request.config.getoption("--codeflash-trace"):
306-
return lambda func, *args, **kwargs: func(*args, **kwargs)
307-
return codeflash_benchmark_plugin.Benchmark(request)
298+
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

0 commit comments

Comments
 (0)