Skip to content

Commit b2c2743

Browse files
authored
Merge pull request #586 from codeflash-ai/benchmark-fixture-fix
Benchmark Fixture fixes
2 parents b478f10 + 974b891 commit b2c2743

File tree

16 files changed

+230
-47
lines changed

16 files changed

+230
-47
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""
2+
3+
__version__ = "0.1.0"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import importlib.util
4+
5+
import pytest
6+
7+
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
8+
9+
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
10+
11+
12+
def pytest_configure(config: pytest.Config) -> None:
13+
"""Register the benchmark marker and disable conflicting plugins."""
14+
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
15+
16+
if config.getoption("--codeflash-trace") and PYTEST_BENCHMARK_INSTALLED:
17+
config.option.benchmark_disable = True
18+
config.pluginmanager.set_blocked("pytest_benchmark")
19+
config.pluginmanager.set_blocked("pytest-benchmark")
20+
21+
22+
def pytest_addoption(parser: pytest.Parser) -> None:
23+
parser.addoption(
24+
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
25+
)
26+
27+
28+
@pytest.fixture
29+
def benchmark(request: pytest.FixtureRequest) -> object:
30+
"""Benchmark fixture that works with or without pytest-benchmark installed."""
31+
config = request.config
32+
33+
# If --codeflash-trace is enabled, use our implementation
34+
if config.getoption("--codeflash-trace"):
35+
return codeflash_benchmark_plugin.Benchmark(request)
36+
37+
# If pytest-benchmark is installed and --codeflash-trace is not enabled,
38+
# return the normal pytest-benchmark fixture
39+
if PYTEST_BENCHMARK_INSTALLED:
40+
from pytest_benchmark.fixture import BenchmarkFixture as BSF # noqa: N814
41+
42+
bs = getattr(config, "_benchmarksession", None)
43+
if bs and bs.skip:
44+
pytest.skip("Benchmarks are skipped (--benchmark-skip was used).")
45+
46+
node = request.node
47+
marker = node.get_closest_marker("benchmark")
48+
options = dict(marker.kwargs) if marker else {}
49+
50+
if bs:
51+
return BSF(
52+
node,
53+
add_stats=bs.benchmarks.append,
54+
logger=bs.logger,
55+
warner=request.node.warn,
56+
disabled=bs.disabled,
57+
**dict(bs.options, **options),
58+
)
59+
return lambda func, *args, **kwargs: func(*args, **kwargs)
60+
61+
return lambda func, *args, **kwargs: func(*args, **kwargs)

codeflash-benchmark/pyproject.toml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
[project]
2+
name = "codeflash-benchmark"
3+
version = "0.1.0"
4+
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
5+
authors = [{ name = "CodeFlash Inc.", email = "[email protected]" }]
6+
requires-python = ">=3.9"
7+
readme = "README.md"
8+
license = {text = "BSL-1.1"}
9+
keywords = [
10+
"codeflash",
11+
"benchmark",
12+
"pytest",
13+
"performance",
14+
"testing",
15+
]
16+
dependencies = [
17+
"pytest>=7.0.0,!=8.3.4",
18+
]
19+
20+
[project.urls]
21+
Homepage = "https://codeflash.ai"
22+
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
23+
24+
[project.entry-points.pytest11]
25+
codeflash-benchmark = "codeflash_benchmark.plugin"
26+
27+
[build-system]
28+
requires = ["setuptools>=45", "wheel", "setuptools_scm"]
29+
build-backend = "setuptools.build_meta"
30+
31+
[tool.setuptools]
32+
packages = ["codeflash_benchmark"]

codeflash/benchmarking/plugin/plugin.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
if TYPE_CHECKING:
1717
from codeflash.models.models import BenchmarkKey
1818

19-
IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
19+
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
2020

2121

2222
class CodeFlashBenchmarkPlugin:
@@ -251,8 +251,12 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
251251

252252
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
253253
"""Actual benchmark implementation."""
254+
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
255+
if node_path is None:
256+
raise RuntimeError("Unable to determine test file path from pytest node")
257+
254258
benchmark_module_path = module_name_from_file_path(
255-
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
259+
Path(str(node_path)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
256260
)
257261

258262
benchmark_function_name = self.request.node.name
@@ -282,26 +286,3 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
282286

283287

284288
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
285-
286-
287-
def pytest_configure(config: pytest.Config) -> None:
288-
"""Register the benchmark marker and disable conflicting plugins."""
289-
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
290-
291-
if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
292-
config.option.benchmark_disable = True
293-
config.pluginmanager.set_blocked("pytest_benchmark")
294-
config.pluginmanager.set_blocked("pytest-benchmark")
295-
296-
297-
def pytest_addoption(parser: pytest.Parser) -> None:
298-
parser.addoption(
299-
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
300-
)
301-
302-
303-
@pytest.fixture
304-
def benchmark(request: pytest.FixtureRequest) -> object:
305-
if not request.config.getoption("--codeflash-trace"):
306-
return lambda func, *args, **kwargs: func(*args, **kwargs)
307-
return codeflash_benchmark_plugin.Benchmark(request)

codeflash/benchmarking/pytest_new_process_trace_benchmarks.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
benchmarks_root = sys.argv[1]
88
tests_root = sys.argv[2]
99
trace_file = sys.argv[3]
10-
# current working directory
1110
project_root = Path.cwd()
11+
1212
if __name__ == "__main__":
1313
import pytest
1414

15+
orig_recursion_limit = sys.getrecursionlimit()
16+
sys.setrecursionlimit(orig_recursion_limit * 2)
17+
1518
try:
1619
codeflash_benchmark_plugin.setup(trace_file, project_root)
1720
codeflash_trace.setup(trace_file)
@@ -32,9 +35,12 @@
3235
"addopts=",
3336
],
3437
plugins=[codeflash_benchmark_plugin],
35-
) # Errors will be printed to stdout, not stderr
36-
38+
)
3739
except Exception as e:
3840
print(f"Failed to collect tests: {e!s}", file=sys.stderr)
3941
exitcode = -1
42+
finally:
43+
# Restore the original recursion limit
44+
sys.setrecursionlimit(orig_recursion_limit)
45+
4046
sys.exit(exitcode)

codeflash/benchmarking/replay_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
import sqlite3
45
import textwrap
56
from pathlib import Path
@@ -14,6 +15,8 @@
1415
if TYPE_CHECKING:
1516
from collections.abc import Generator
1617

18+
benchmark_context_cleaner = re.compile(r"[^a-zA-Z0-9_]+")
19+
1720

1821
def get_next_arg_and_return(
1922
trace_file: str,
@@ -46,6 +49,16 @@ def get_function_alias(module: str, function_name: str) -> str:
4649
return "_".join(module.split(".")) + "_" + function_name
4750

4851

52+
def get_unique_test_name(module: str, function_name: str, benchmark_name: str, class_name: str | None = None) -> str:
53+
clean_benchmark = benchmark_context_cleaner.sub("_", benchmark_name).strip("_")
54+
55+
base_alias = get_function_alias(module, function_name)
56+
if class_name:
57+
class_alias = get_function_alias(module, class_name)
58+
return f"{class_alias}_{function_name}_{clean_benchmark}"
59+
return f"{base_alias}_{clean_benchmark}"
60+
61+
4962
def create_trace_replay_test_code(
5063
trace_file: str,
5164
functions_data: list[dict[str, Any]],
@@ -209,7 +222,8 @@ def create_trace_replay_test_code(
209222
formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
210223

211224
test_template += " " if test_framework == "unittest" else ""
212-
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
225+
unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
226+
test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
213227

214228
return imports + "\n" + metadata + "\n" + test_template
215229

codeflash/discovery/discover_unit_tests.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,13 @@ def __init__(self, function_names_to_find: set[str]) -> None:
150150
self.imported_modules: set[str] = set()
151151
self.has_dynamic_imports: bool = False
152152
self.wildcard_modules: set[str] = set()
153+
# Track aliases: alias_name -> original_name
154+
self.alias_mapping: dict[str, str] = {}
153155

154156
# Precompute function_names for prefix search
155157
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
156158
self._exact_names = function_names_to_find
157-
self._prefix_roots = {}
159+
self._prefix_roots: dict[str, list[str]] = {}
158160
for name in function_names_to_find:
159161
if "." in name:
160162
root = name.split(".", 1)[0]
@@ -206,6 +208,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
206208
imported_name = alias.asname if alias.asname else aname
207209
self.imported_modules.add(imported_name)
208210

211+
if alias.asname:
212+
self.alias_mapping[imported_name] = aname
213+
209214
# Fast check for dynamic import
210215
if mod == "importlib" and aname == "import_module":
211216
self.has_dynamic_imports = True
@@ -222,7 +227,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
222227
self.found_qualified_name = qname
223228
return
224229

225-
# Fast prefix match: only for relevant roots
226230
prefix = qname + "."
227231
# Only bother if one of the targets startswith the prefix-root
228232
candidates = proots.get(qname, ())
@@ -247,6 +251,18 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
247251
self.found_qualified_name = node.attr
248252
return
249253

254+
if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
255+
for target_func in self.function_names_to_find:
256+
if "." in target_func:
257+
class_name, method_name = target_func.rsplit(".", 1)
258+
if node.attr == method_name:
259+
imported_name = node.value.id
260+
original_name = self.alias_mapping.get(imported_name, imported_name)
261+
if original_name == class_name:
262+
self.found_any_target_function = True
263+
self.found_qualified_name = target_func
264+
return
265+
250266
# Check if this is accessing a target function through a dynamically imported module
251267
# Only if we've detected dynamic imports are being used
252268
if self.has_dynamic_imports and node.attr in self.function_names_to_find:

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def get_functions_to_optimize(
204204
functions[file] = [found_function]
205205
else:
206206
logger.info("Finding all functions modified in the current git diff ...")
207+
console.rule()
207208
ph("cli-optimizing-git-diff")
208209
functions = get_functions_within_git_diff()
209210
filtered_modified_functions, functions_count = filter_functions(

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ def group_by_benchmarks(
515515
benchmark_replay_test_dir.resolve()
516516
/ f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_",
517517
project_root,
518+
traverse_up=True,
518519
)
519520
for test_result in self.test_results:
520521
if test_result.test_type == TestType.REPLAY_TEST:

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
840840
f"{concolic_coverage_test_files_count} concolic coverage test file"
841841
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
842842
)
843+
console.rule()
843844
return unique_instrumented_test_files
844845

845846
def generate_tests_and_optimizations(

0 commit comments

Comments
 (0)