Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions codeflash-benchmark/codeflash_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""

__version__ = "0.1.0"
61 changes: 61 additions & 0 deletions codeflash-benchmark/codeflash_benchmark/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

import importlib.util

import pytest

from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin

PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None


def pytest_configure(config: pytest.Config) -> None:
    """Register the ``benchmark`` marker and, in tracing mode, block pytest-benchmark.

    The marker is always registered so ``@pytest.mark.benchmark`` never warns.
    When ``--codeflash-trace`` is active and pytest-benchmark is installed, the
    competing plugin is disabled and blocked so its fixture cannot conflict.
    """
    config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")

    if not config.getoption("--codeflash-trace"):
        return

    if PYTEST_BENCHMARK_INSTALLED:
        config.option.benchmark_disable = True
        # pytest-benchmark may be registered under either naming variant.
        for blocked_name in ("pytest_benchmark", "pytest-benchmark"):
            config.pluginmanager.set_blocked(blocked_name)


def pytest_addoption(parser: pytest.Parser) -> None:
    """Expose the ``--codeflash-trace`` flag that switches on CodeFlash benchmark tracing."""
    parser.addoption(
        "--codeflash-trace",
        action="store_true",
        default=False,
        help="Enable CodeFlash tracing for benchmarks",
    )


@pytest.fixture
def benchmark(request: pytest.FixtureRequest) -> object:
    """Benchmark fixture that works with or without pytest-benchmark installed.

    Resolution order:
    1. ``--codeflash-trace`` enabled -> codeflash's own tracing Benchmark object.
    2. pytest-benchmark installed with an active session -> its regular fixture
       (honouring ``--benchmark-skip`` and the disabled flag).
    3. Otherwise -> a pass-through callable that simply invokes the function.
    """
    cfg = request.config

    if cfg.getoption("--codeflash-trace"):
        # Tracing mode: hand out codeflash's benchmark recorder.
        return codeflash_benchmark_plugin.Benchmark(request)

    def passthrough(func, *args, **kwargs):  # noqa: ANN001, ANN002, ANN003, ANN202
        return func(*args, **kwargs)

    if not PYTEST_BENCHMARK_INSTALLED:
        return passthrough

    from pytest_benchmark.fixture import BenchmarkFixture as BSF  # noqa: N814

    # _benchmarksession is pytest-benchmark's private session object; it may be
    # absent if the plugin was loaded but never initialised a session.
    session = getattr(cfg, "_benchmarksession", None)
    if session and session.skip:
        pytest.skip("Benchmarks are skipped (--benchmark-skip was used).")
    if not session:
        return passthrough

    marker = request.node.get_closest_marker("benchmark")
    per_test_options = dict(marker.kwargs) if marker else {}
    return BSF(
        request.node,
        add_stats=session.benchmarks.append,
        logger=session.logger,
        warner=request.node.warn,
        disabled=session.disabled,
        # Per-test marker kwargs override the session-wide options.
        **dict(session.options, **per_test_options),
    )
32 changes: 32 additions & 0 deletions codeflash-benchmark/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Packaging metadata for the standalone codeflash-benchmark pytest plugin.
[project]
name = "codeflash-benchmark"
version = "0.1.0"
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
authors = [{ name = "CodeFlash Inc.", email = "[email protected]" }]
requires-python = ">=3.9"
readme = "README.md"
license = {text = "BSL-1.1"}
keywords = [
"codeflash",
"benchmark",
"pytest",
"performance",
"testing",
]
dependencies = [
# pytest 8.3.4 is excluded; presumably a known-incompatible release — TODO confirm.
"pytest>=7.0.0,!=8.3.4",
]

[project.urls]
Homepage = "https://codeflash.ai"
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"

# Registers codeflash_benchmark.plugin with pytest's plugin entry-point group.
[project.entry-points.pytest11]
codeflash-benchmark = "codeflash_benchmark.plugin"

[build-system]
requires = ["setuptools>=45", "wheel", "setuptools_scm"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["codeflash_benchmark"]
31 changes: 6 additions & 25 deletions codeflash/benchmarking/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from codeflash.models.models import BenchmarkKey

IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None


class CodeFlashBenchmarkPlugin:
Expand Down Expand Up @@ -251,8 +251,12 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202

def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
"""Actual benchmark implementation."""
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
if node_path is None:
raise RuntimeError("Unable to determine test file path from pytest node")

benchmark_module_path = module_name_from_file_path(
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
Path(str(node_path)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
)

benchmark_function_name = self.request.node.name
Expand Down Expand Up @@ -282,26 +286,3 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003


codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()


def pytest_configure(config: pytest.Config) -> None:
"""Register the benchmark marker and disable conflicting plugins."""
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")

if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
config.option.benchmark_disable = True
config.pluginmanager.set_blocked("pytest_benchmark")
config.pluginmanager.set_blocked("pytest-benchmark")


def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
)


@pytest.fixture
def benchmark(request: pytest.FixtureRequest) -> object:
if not request.config.getoption("--codeflash-trace"):
return lambda func, *args, **kwargs: func(*args, **kwargs)
return codeflash_benchmark_plugin.Benchmark(request)
12 changes: 9 additions & 3 deletions codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
benchmarks_root = sys.argv[1]
tests_root = sys.argv[2]
trace_file = sys.argv[3]
# current working directory
project_root = Path.cwd()

if __name__ == "__main__":
import pytest

orig_recursion_limit = sys.getrecursionlimit()
sys.setrecursionlimit(orig_recursion_limit * 2)

try:
codeflash_benchmark_plugin.setup(trace_file, project_root)
codeflash_trace.setup(trace_file)
Expand All @@ -32,9 +35,12 @@
"addopts=",
],
plugins=[codeflash_benchmark_plugin],
) # Errors will be printed to stdout, not stderr

)
except Exception as e:
print(f"Failed to collect tests: {e!s}", file=sys.stderr)
exitcode = -1
finally:
# Restore the original recursion limit
sys.setrecursionlimit(orig_recursion_limit)

sys.exit(exitcode)
16 changes: 15 additions & 1 deletion codeflash/benchmarking/replay_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import sqlite3
import textwrap
from pathlib import Path
Expand All @@ -14,6 +15,8 @@
if TYPE_CHECKING:
from collections.abc import Generator

benchmark_context_cleaner = re.compile(r"[^a-zA-Z0-9_]+")


def get_next_arg_and_return(
trace_file: str,
Expand Down Expand Up @@ -46,6 +49,16 @@ def get_function_alias(module: str, function_name: str) -> str:
return "_".join(module.split(".")) + "_" + function_name


def get_unique_test_name(module: str, function_name: str, benchmark_name: str, class_name: str | None = None) -> str:
    """Build a collision-free replay-test name from module, function and benchmark context.

    The benchmark name is sanitised down to identifier-safe characters.  When a
    class name is supplied, the alias is derived from the class (with the plain
    function name appended) rather than from the function itself.
    """
    sanitized_benchmark = benchmark_context_cleaner.sub("_", benchmark_name).strip("_")

    if class_name:
        return f"{get_function_alias(module, class_name)}_{function_name}_{sanitized_benchmark}"
    return f"{get_function_alias(module, function_name)}_{sanitized_benchmark}"


def create_trace_replay_test_code(
trace_file: str,
functions_data: list[dict[str, Any]],
Expand Down Expand Up @@ -209,7 +222,8 @@ def create_trace_replay_test_code(
formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")

test_template += " " if test_framework == "unittest" else ""
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"

return imports + "\n" + metadata + "\n" + test_template

Expand Down
20 changes: 18 additions & 2 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,13 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.imported_modules: set[str] = set()
self.has_dynamic_imports: bool = False
self.wildcard_modules: set[str] = set()
# Track aliases: alias_name -> original_name
self.alias_mapping: dict[str, str] = {}

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots = {}
self._prefix_roots: dict[str, list[str]] = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
Expand Down Expand Up @@ -206,6 +208,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
imported_name = alias.asname if alias.asname else aname
self.imported_modules.add(imported_name)

if alias.asname:
self.alias_mapping[imported_name] = aname

# Fast check for dynamic import
if mod == "importlib" and aname == "import_module":
self.has_dynamic_imports = True
Expand All @@ -222,7 +227,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.found_qualified_name = qname
return

# Fast prefix match: only for relevant roots
prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
Expand All @@ -247,6 +251,18 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
self.found_qualified_name = node.attr
return

if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
for target_func in self.function_names_to_find:
if "." in target_func:
class_name, method_name = target_func.rsplit(".", 1)
if node.attr == method_name:
imported_name = node.value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name == class_name:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
Expand Down
1 change: 1 addition & 0 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def get_functions_to_optimize(
functions[file] = [found_function]
else:
logger.info("Finding all functions modified in the current git diff ...")
console.rule()
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
Expand Down
1 change: 1 addition & 0 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ def group_by_benchmarks(
benchmark_replay_test_dir.resolve()
/ f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_",
project_root,
traverse_up=True,
)
for test_result in self.test_results:
if test_result.test_type == TestType.REPLAY_TEST:
Expand Down
1 change: 1 addition & 0 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
f"{concolic_coverage_test_files_count} concolic coverage test file"
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
)
console.rule()
return unique_instrumented_test_files

def generate_tests_and_optimizations(
Expand Down
2 changes: 1 addition & 1 deletion codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run_benchmarks(
for file in file_path_to_source_code:
with file.open("w", encoding="utf8") as f:
f.write(file_path_to_source_code[file])

console.rule()
return function_benchmark_timings, total_benchmark_timings

def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"line_profiler>=4.2.0",
"platformdirs>=4.3.7",
"pygls>=1.3.1",
"codeflash-benchmark",
]

[project.urls]
Expand Down Expand Up @@ -262,6 +263,12 @@ skip-magic-trailing-comma = true
[tool.hatch.version]
source = "uv-dynamic-versioning"

[tool.uv]
workspace = { members = ["codeflash-benchmark"] }

[tool.uv.sources]
codeflash-benchmark = { workspace = true }

[tool.uv-dynamic-versioning]
enable = true
style = "pep440"
Expand Down Expand Up @@ -300,5 +307,3 @@ markers = [
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"

[project.entry-points.pytest11]
codeflash = "codeflash.benchmarking.plugin.plugin"
10 changes: 5 additions & 5 deletions tests/test_pickle_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_run_and_parse_picklepatch() -> None:
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.REPLAY_TEST
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
test_files=[
Expand All @@ -388,7 +388,7 @@ def test_run_and_parse_picklepatch() -> None:
)
assert len(test_results_unused_socket) == 1
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
assert test_results_unused_socket.test_results[0].did_pass == True

# Replace with optimized candidate
Expand Down Expand Up @@ -439,7 +439,7 @@ def bubble_sort_with_unused_socket(data_container):
test_type = TestType.REPLAY_TEST
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
file_path=Path(fto_used_socket_path))
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
test_files=[
Expand Down Expand Up @@ -467,7 +467,7 @@ def bubble_sort_with_unused_socket(data_container):
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
assert test_results_used_socket.test_results[0].did_pass is False
print("test results used socket")
print(test_results_used_socket)
Expand Down Expand Up @@ -498,7 +498,7 @@ def bubble_sort_with_used_socket(data_container):
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket_test_used_socket_picklepatch"
assert test_results_used_socket.test_results[0].did_pass is False

# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
Expand Down
Loading
Loading