Merged
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yaml
@@ -30,4 +30,4 @@ jobs:
         run: uv sync
 
       - name: Unit tests
-        run: uv run pytest tests/ --benchmark-skip -m "not ci_skip"
+        run: uv run pytest tests/
61 changes: 35 additions & 26 deletions codeflash/benchmarking/plugin/plugin.py
@@ -1,16 +1,22 @@
 from __future__ import annotations
 
+import importlib.util
 import os
 import sqlite3
 import sys
 import time
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import pytest
 
 from codeflash.benchmarking.codeflash_trace import codeflash_trace
 from codeflash.code_utils.code_utils import module_name_from_file_path
-from codeflash.models.models import BenchmarkKey
 
+if TYPE_CHECKING:
+    from codeflash.models.models import BenchmarkKey
+
+IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None


class CodeFlashBenchmarkPlugin:
@@ -71,6 +77,8 @@ def close(self) -> None:

     @staticmethod
     def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
+        from codeflash.models.models import BenchmarkKey
+
         """Process the trace file and extract timing data for all functions.
 
         Args:
Expand Down Expand Up @@ -131,6 +139,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark

     @staticmethod
     def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
+        from codeflash.models.models import BenchmarkKey
+
         """Extract total benchmark timings from trace files.
 
         Args:
Expand Down Expand Up @@ -199,23 +209,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
         # Close the database connection
         self.close()
 
-    @staticmethod
-    def pytest_addoption(parser: pytest.Parser) -> None:
-        parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
-
-    @staticmethod
-    def pytest_plugin_registered(plugin, manager) -> None:  # noqa: ANN001
-        # Not necessary since run with -p no:benchmark, but just in case
-        if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
-            manager.unregister(plugin)
-
-    @staticmethod
-    def pytest_configure(config: pytest.Config) -> None:
-        """Register the benchmark marker."""
-        config.addinivalue_line(
-            "markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
-        )
-
     @staticmethod
     def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
         # Skip tests that don't have the benchmark fixture
@@ -259,8 +252,9 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN202
         def _run_benchmark(self, func, *args, **kwargs):  # noqa: ANN001, ANN002, ANN003, ANN202
             """Actual benchmark implementation."""
             benchmark_module_path = module_name_from_file_path(
-                Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
+                Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
             )
+
             benchmark_function_name = self.request.node.name
             line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack  # noqa: SLF001
             # Set env vars
@@ -286,13 +280,28 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003

             return result
 
-    @staticmethod
-    @pytest.fixture
-    def benchmark(request: pytest.FixtureRequest) -> object:
-        if not request.config.getoption("--codeflash-trace"):
-            return None
-
-        return CodeFlashBenchmarkPlugin.Benchmark(request)
+
+codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
 
 
-codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
+def pytest_configure(config: pytest.Config) -> None:
+    """Register the benchmark marker and disable conflicting plugins."""
+    config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
+
+    if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
+        config.option.benchmark_disable = True
+        config.pluginmanager.set_blocked("pytest_benchmark")
+        config.pluginmanager.set_blocked("pytest-benchmark")
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    parser.addoption(
+        "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
+    )
+
+
+@pytest.fixture
+def benchmark(request: pytest.FixtureRequest) -> object:
+    if not request.config.getoption("--codeflash-trace"):
+        return lambda func, *args, **kwargs: func(*args, **kwargs)
+    return codeflash_benchmark_plugin.Benchmark(request)
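
With the hooks and fixture moved to module level (auto-registered via the `pytest11` entry point added in `pyproject.toml` below), the `benchmark` fixture now degrades gracefully: without `--codeflash-trace` it returns a pass-through callable instead of `None`, so benchmark tests still execute as plain tests. A minimal usage sketch (the test and `bubble_sort` are hypothetical; the `benchmark(func, *args, **kwargs)` calling convention is inferred from the fallback lambda and the `_run_benchmark` signature):

```python
import pytest


@pytest.mark.benchmark
def test_bubble_sort(benchmark):
    # Hypothetical function under test.
    def bubble_sort(items):
        data = list(items)
        for i in range(len(data)):
            for j in range(len(data) - i - 1):
                if data[j] > data[j + 1]:
                    data[j], data[j + 1] = data[j + 1], data[j]
        return data

    # With --codeflash-trace: the call goes through the Benchmark object and
    # timings are written to the trace database. Without the flag: the
    # pass-through lambda just invokes bubble_sort, so the assertion still runs.
    result = benchmark(bubble_sort, [5, 3, 1, 4])
    assert result == [1, 3, 4, 5]
```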
18 changes: 15 additions & 3 deletions codeflash/code_utils/code_utils.py
@@ -109,9 +109,21 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
     return full_qualified_name[len(module_name) + 1 :]
 
 
-def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str:
-    relative_path = file_path.relative_to(project_root_path)
-    return relative_path.with_suffix("").as_posix().replace("/", ".")
+def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
+    try:
+        relative_path = file_path.relative_to(project_root_path)
+        return relative_path.with_suffix("").as_posix().replace("/", ".")
+    except ValueError:
+        if traverse_up:
+            parent = file_path.parent
+            while parent not in (project_root_path, parent.parent):
+                try:
+                    relative_path = file_path.relative_to(parent)
+                    return relative_path.with_suffix("").as_posix().replace("/", ".")
+                except ValueError:
+                    parent = parent.parent
+        msg = f"File {file_path} is not within the project root {project_root_path}."
+        raise ValueError(msg)  # noqa: B904
 
 
 def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
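
The `traverse_up` fallback is worth a sketch, since its behavior is subtle: when `file_path` is outside `project_root_path`, the walk starts at the file's own directory, which always matches first, so the derived name collapses to the file's stem; without `traverse_up`, the `ValueError` propagates as before. Paths below are hypothetical:

```python
from pathlib import Path

from codeflash.code_utils.code_utils import module_name_from_file_path

root = Path("/repo/src")

# Inside the project root: unchanged behavior, a dotted module path.
assert module_name_from_file_path(Path("/repo/src/pkg/util.py"), root) == "pkg.util"

# Outside the root with traverse_up=True: file_path.relative_to(file_path.parent)
# succeeds on the first loop iteration, so the result is just the stem.
assert (
    module_name_from_file_path(Path("/repo/benchmarks/test_bench.py"), root, traverse_up=True)
    == "test_bench"
)

# Outside the root without traverse_up: raises ValueError.
```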
1 change: 1 addition & 0 deletions codeflash/optimization/optimizer.py
@@ -65,6 +65,7 @@ def run_benchmarks(
         from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
         from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
 
+        console.rule()
         with progress_bar(
             f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
         ):
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -67,7 +67,6 @@ dev = [
"types-openpyxl>=3.1.5.20241020",
"types-regex>=2024.9.11.20240912",
"types-python-dateutil>=2.9.0.20241003",
"pytest-benchmark>=5.1.0",
"types-gevent>=24.11.0.20241230,<25",
"types-greenlet>=3.1.0.20241221,<4",
"types-pexpect>=4.9.0.20241208,<5",
@@ -300,3 +299,6 @@ markers = [
 [build-system]
 requires = ["hatchling", "uv-dynamic-versioning"]
 build-backend = "hatchling.build"
+
+[project.entry-points.pytest11]
+codeflash = "codeflash.benchmarking.plugin.plugin"
2 changes: 1 addition & 1 deletion tests/test_trace_benchmarks.py
@@ -33,7 +33,7 @@ def test_trace_benchmarks() -> None:
     function_calls = cursor.fetchall()
 
     # Assert the length of function calls
-    assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
+    assert len(function_calls) == 7, f"Expected 7 function calls, but got {len(function_calls)}"
 
     bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
     process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
30 changes: 5 additions & 25 deletions uv.lock

Some generated files are not rendered by default.