Skip to content

Commit b598916

Browse files
authored
Merge branch 'main' into saga4/FD_leak
2 parents 1b3c83c + 34c3aa9 commit b598916

File tree

7 files changed

+61
-57
lines changed

7 files changed

+61
-57
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
run: uv sync
3131

3232
- name: Unit tests
33-
run: uv run pytest tests/ --benchmark-skip -m "not ci_skip"
33+
run: uv run pytest tests/

codeflash/benchmarking/plugin/plugin.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
import importlib.util
34
import os
45
import sqlite3
56
import sys
67
import time
78
from pathlib import Path
9+
from typing import TYPE_CHECKING
810

911
import pytest
1012

1113
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1214
from codeflash.code_utils.code_utils import module_name_from_file_path
13-
from codeflash.models.models import BenchmarkKey
15+
16+
if TYPE_CHECKING:
17+
from codeflash.models.models import BenchmarkKey
18+
19+
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
1420

1521

1622
class CodeFlashBenchmarkPlugin:
@@ -71,6 +77,8 @@ def close(self) -> None:
7177

7278
@staticmethod
7379
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
80+
from codeflash.models.models import BenchmarkKey
81+
7482
"""Process the trace file and extract timing data for all functions.
7583
7684
Args:
@@ -131,6 +139,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
131139

132140
@staticmethod
133141
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
142+
from codeflash.models.models import BenchmarkKey
143+
134144
"""Extract total benchmark timings from trace files.
135145
136146
Args:
@@ -199,23 +209,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
199209
# Close the database connection
200210
self.close()
201211

202-
@staticmethod
203-
def pytest_addoption(parser: pytest.Parser) -> None:
204-
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205-
206-
@staticmethod
207-
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
208-
# Not necessary since run with -p no:benchmark, but just in case
209-
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
210-
manager.unregister(plugin)
211-
212-
@staticmethod
213-
def pytest_configure(config: pytest.Config) -> None:
214-
"""Register the benchmark marker."""
215-
config.addinivalue_line(
216-
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
217-
)
218-
219212
@staticmethod
220213
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
221214
# Skip tests that don't have the benchmark fixture
@@ -259,8 +252,9 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
259252
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
260253
"""Actual benchmark implementation."""
261254
benchmark_module_path = module_name_from_file_path(
262-
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
255+
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
263256
)
257+
264258
benchmark_function_name = self.request.node.name
265259
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
266260
# Set env vars
@@ -286,13 +280,28 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
286280

287281
return result
288282

289-
@staticmethod
290-
@pytest.fixture
291-
def benchmark(request: pytest.FixtureRequest) -> object:
292-
if not request.config.getoption("--codeflash-trace"):
293-
return None
294283

295-
return CodeFlashBenchmarkPlugin.Benchmark(request)
284+
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
296285

297286

298-
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
287+
def pytest_configure(config: pytest.Config) -> None:
288+
"""Register the benchmark marker and disable conflicting plugins."""
289+
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
290+
291+
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
292+
config.option.benchmark_disable = True
293+
config.pluginmanager.set_blocked("pytest_benchmark")
294+
config.pluginmanager.set_blocked("pytest-benchmark")
295+
296+
297+
def pytest_addoption(parser: pytest.Parser) -> None:
298+
parser.addoption(
299+
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
300+
)
301+
302+
303+
@pytest.fixture
304+
def benchmark(request: pytest.FixtureRequest) -> object:
305+
if not request.config.getoption("--codeflash-trace"):
306+
return lambda func, *args, **kwargs: func(*args, **kwargs)
307+
return codeflash_benchmark_plugin.Benchmark(request)

codeflash/code_utils/code_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,21 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
109109
return full_qualified_name[len(module_name) + 1 :]
110110

111111

112-
def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str:
113-
relative_path = file_path.relative_to(project_root_path)
114-
return relative_path.with_suffix("").as_posix().replace("/", ".")
112+
def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
113+
try:
114+
relative_path = file_path.relative_to(project_root_path)
115+
return relative_path.with_suffix("").as_posix().replace("/", ".")
116+
except ValueError:
117+
if traverse_up:
118+
parent = file_path.parent
119+
while parent not in (project_root_path, parent.parent):
120+
try:
121+
relative_path = file_path.relative_to(parent)
122+
return relative_path.with_suffix("").as_posix().replace("/", ".")
123+
except ValueError:
124+
parent = parent.parent
125+
msg = f"File {file_path} is not within the project root {project_root_path}."
126+
raise ValueError(msg) # noqa: B904
115127

116128

117129
def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def run_benchmarks(
6565
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
6666
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
6767

68+
console.rule()
6869
with progress_bar(
6970
f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
7071
):

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ dev = [
6767
"types-openpyxl>=3.1.5.20241020",
6868
"types-regex>=2024.9.11.20240912",
6969
"types-python-dateutil>=2.9.0.20241003",
70-
"pytest-benchmark>=5.1.0",
7170
"types-gevent>=24.11.0.20241230,<25",
7271
"types-greenlet>=3.1.0.20241221,<4",
7372
"types-pexpect>=4.9.0.20241208,<5",
@@ -300,3 +299,6 @@ markers = [
300299
[build-system]
301300
requires = ["hatchling", "uv-dynamic-versioning"]
302301
build-backend = "hatchling.build"
302+
303+
[project.entry-points.pytest11]
304+
codeflash = "codeflash.benchmarking.plugin.plugin"

tests/test_trace_benchmarks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_trace_benchmarks() -> None:
3333
function_calls = cursor.fetchall()
3434

3535
# Assert the length of function calls
36-
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
36+
assert len(function_calls) == 7, f"Expected 7 function calls, but got {len(function_calls)}"
3737

3838
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
3939
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()

uv.lock

Lines changed: 5 additions & 25 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)