
Commit 00d78a9

Merge branch 'main' of https://github.com/codeflash-ai/codeflash into part-1-windows-fixes

2 parents: fe060f8 + d5ec766

File tree: 22 files changed, +577 −144 lines
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""
+
+__version__ = "0.1.0"
Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import importlib.util
+
+import pytest
+
+from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
+
+PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
+
+
+def pytest_configure(config: pytest.Config) -> None:
+    """Register the benchmark marker and disable conflicting plugins."""
+    config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
+
+    if config.getoption("--codeflash-trace") and PYTEST_BENCHMARK_INSTALLED:
+        config.option.benchmark_disable = True
+        config.pluginmanager.set_blocked("pytest_benchmark")
+        config.pluginmanager.set_blocked("pytest-benchmark")
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    parser.addoption(
+        "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
+    )
+
+
+@pytest.fixture
+def benchmark(request: pytest.FixtureRequest) -> object:
+    """Benchmark fixture that works with or without pytest-benchmark installed."""
+    config = request.config
+
+    # If --codeflash-trace is enabled, use our implementation
+    if config.getoption("--codeflash-trace"):
+        return codeflash_benchmark_plugin.Benchmark(request)
+
+    # If pytest-benchmark is installed and --codeflash-trace is not enabled,
+    # return the normal pytest-benchmark fixture
+    if PYTEST_BENCHMARK_INSTALLED:
+        from pytest_benchmark.fixture import BenchmarkFixture as BSF  # noqa: N814
+
+        bs = getattr(config, "_benchmarksession", None)
+        if bs and bs.skip:
+            pytest.skip("Benchmarks are skipped (--benchmark-skip was used).")
+
+        node = request.node
+        marker = node.get_closest_marker("benchmark")
+        options = dict(marker.kwargs) if marker else {}
+
+        if bs:
+            return BSF(
+                node,
+                add_stats=bs.benchmarks.append,
+                logger=bs.logger,
+                warner=request.node.warn,
+                disabled=bs.disabled,
+                **dict(bs.options, **options),
+            )
+        return lambda func, *args, **kwargs: func(*args, **kwargs)
+
+    return lambda func, *args, **kwargs: func(*args, **kwargs)
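For context, a minimal sketch of how a test might consume this fixture; the workload and test names are illustrative, not part of the commit:

    import pytest


    def compute() -> int:
        # Hypothetical workload used only for illustration.
        return sum(i * i for i in range(1_000))


    @pytest.mark.benchmark
    def test_compute(benchmark) -> None:
        # Whichever implementation the fixture resolves to (codeflash tracing,
        # pytest-benchmark, or the plain passthrough lambda), the call shape is
        # the same: benchmark(func, *args, **kwargs) returns func's result.
        result = benchmark(compute)
        assert result == compute()

Because all three fallbacks share the `benchmark(func, *args, **kwargs)` calling convention, the same test runs unchanged with or without `--codeflash-trace`.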

codeflash-benchmark/pyproject.toml

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+[project]
+name = "codeflash-benchmark"
+version = "0.1.0"
+description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
+authors = [{ name = "CodeFlash Inc.", email = "[email protected]" }]
+requires-python = ">=3.9"
+readme = "README.md"
+license = {text = "BSL-1.1"}
+keywords = [
+    "codeflash",
+    "benchmark",
+    "pytest",
+    "performance",
+    "testing",
+]
+dependencies = [
+    "pytest>=7.0.0,!=8.3.4",
+]
+
+[project.urls]
+Homepage = "https://codeflash.ai"
+Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
+
+[project.entry-points.pytest11]
+codeflash-benchmark = "codeflash_benchmark.plugin"
+
+[build-system]
+requires = ["setuptools>=45", "wheel", "setuptools_scm"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["codeflash_benchmark"]
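The `pytest11` entry point is what lets pytest discover the plugin automatically once the package is installed, with no conftest.py changes. A quick way to check the registration from Python (a sketch; note that the `group=` keyword landed in Python 3.10, so on 3.9 you would index `entry_points()["pytest11"]` instead):

    from importlib.metadata import entry_points

    # List every pytest plugin registered under the pytest11 entry-point
    # group; "codeflash-benchmark" should appear once the package is installed.
    for ep in entry_points(group="pytest11"):
        print(ep.name, "->", ep.value)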

codeflash/api/aiservice.py

Lines changed: 81 additions & 7 deletions

@@ -12,7 +12,8 @@
 from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
 from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
-from codeflash.models.models import OptimizedCandidate
+from codeflash.models.ExperimentMetadata import ExperimentMetadata
+from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
 from codeflash.telemetry.posthog_cf import ph
 from codeflash.version import __version__ as codeflash_version

@@ -21,6 +22,7 @@

     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
     from codeflash.models.ExperimentMetadata import ExperimentMetadata
+    from codeflash.models.models import AIServiceRefinerRequest


 class AiServiceClient:

@@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
         return "https://app.codeflash.ai"

     def make_ai_service_request(
-        self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
+        self,
+        endpoint: str,
+        method: str = "POST",
+        payload: dict[str, Any] | list[dict[str, Any]] | None = None,
+        timeout: float | None = None,
     ) -> requests.Response:
         """Make an API request to the given endpoint on the AI service.

@@ -98,11 +104,7 @@ def optimize_python_code(  # noqa: D417

         """
         start_time = time.perf_counter()
-        try:
-            git_repo_owner, git_repo_name = get_repo_owner_and_name()
-        except Exception as e:
-            logger.warning(f"Could not determine repo owner and name: {e}")
-            git_repo_owner, git_repo_name = None, None
+        git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()

         payload = {
             "source_code": source_code,

@@ -219,13 +221,72 @@ def optimize_python_code_line_profiler(  # noqa: D417
         console.rule()
         return []

+    def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
+        """Optimize the given python code for performance by making a request to the Django endpoint.
+
+        Args:
+            request: A list of optimization candidate details for refinement
+
+        Returns:
+        -------
+        - List[OptimizationCandidate]: A list of Optimization Candidates.
+
+        """
+        payload = [
+            {
+                "optimization_id": opt.optimization_id,
+                "original_source_code": opt.original_source_code,
+                "read_only_dependency_code": opt.read_only_dependency_code,
+                "original_line_profiler_results": opt.original_line_profiler_results,
+                "original_code_runtime": opt.original_code_runtime,
+                "optimized_source_code": opt.optimized_source_code,
+                "optimized_explanation": opt.optimized_explanation,
+                "optimized_line_profiler_results": opt.optimized_line_profiler_results,
+                "optimized_code_runtime": opt.optimized_code_runtime,
+                "speedup": opt.speedup,
+                "trace_id": opt.trace_id,
+            }
+            for opt in request
+        ]
+        logger.info(f"Refining {len(request)} optimizations…")
+        console.rule()
+        try:
+            response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
+        except requests.exceptions.RequestException as e:
+            logger.exception(f"Error generating optimization refinements: {e}")
+            ph("cli-optimize-error-caught", {"error": str(e)})
+            return []
+
+        if response.status_code == 200:
+            refined_optimizations = response.json()["refinements"]
+            logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
+            console.rule()
+            return [
+                OptimizedCandidate(
+                    source_code=opt["source_code"],
+                    explanation=opt["explanation"],
+                    optimization_id=opt["optimization_id"][:-4] + "refi",
+                )
+                for opt in refined_optimizations
+            ]
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        console.rule()
+        return []
+
     def log_results(  # noqa: D417
         self,
         function_trace_id: str,
         speedup_ratio: dict[str, float | None] | None,
         original_runtime: float | None,
         optimized_runtime: dict[str, float | None] | None,
         is_correct: dict[str, bool] | None,
+        optimized_line_profiler_results: dict[str, str] | None,
+        metadata: dict[str, Any] | None,
     ) -> None:
         """Log features to the database.

@@ -236,6 +297,8 @@ def log_results(  # noqa: D417
         - original_runtime (Optional[Dict[str, float]]): The original runtime.
         - optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
         - is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
+        - optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
+        - metadata: contains the best optimization id

         """
         payload = {

@@ -245,6 +308,8 @@ def log_results(  # noqa: D417
             "optimized_runtime": optimized_runtime,
             "is_correct": is_correct,
             "codeflash_version": codeflash_version,
+            "optimized_line_profiler_results": optimized_line_profiler_results,
+            "metadata": metadata,
         }
         try:
             self.make_ai_service_request("/log_features", payload=payload, timeout=5)

@@ -331,3 +396,12 @@ class LocalAiServiceClient(AiServiceClient):
     def get_aiservice_base_url(self) -> str:
         """Get the base URL for the local AI service."""
         return "http://localhost:8000"
+
+
+def safe_get_repo_owner_and_name() -> tuple[str | None, str | None]:
+    try:
+        git_repo_owner, git_repo_name = get_repo_owner_and_name()
+    except Exception as e:
+        logger.warning(f"Could not determine repo owner and name: {e}")
+        git_repo_owner, git_repo_name = None, None
+    return git_repo_owner, git_repo_name
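A hedged sketch of driving the new refinement endpoint end to end; the field values are placeholders, the no-argument constructor is an assumption, and `AIServiceRefinerRequest` is assumed to be a model whose attributes match the payload keys serialized above:

    from codeflash.api.aiservice import AiServiceClient
    from codeflash.models.models import AIServiceRefinerRequest

    client = AiServiceClient()
    req = AIServiceRefinerRequest(  # placeholder values throughout
        optimization_id="abcd1234wxyz",
        original_source_code="def f(xs):\n    return sum(xs)\n",
        read_only_dependency_code="",
        original_line_profiler_results="<line_profiler output>",
        original_code_runtime="120000",
        optimized_source_code="def f(xs):\n    return sum(list(xs))\n",
        optimized_explanation="<candidate explanation>",
        optimized_line_profiler_results="<line_profiler output>",
        optimized_code_runtime="95000",
        speedup="1.26",
        trace_id="trace-001",
    )
    for candidate in client.optimize_python_code_refinement([req]):
        # Refined candidates reuse the original id with its last four
        # characters replaced by "refi" (see the list comprehension above).
        print(candidate.optimization_id, candidate.explanation)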

codeflash/benchmarking/plugin/plugin.py

Lines changed: 6 additions & 25 deletions

@@ -16,7 +16,7 @@
 if TYPE_CHECKING:
     from codeflash.models.models import BenchmarkKey

-IS_PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
+PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None


 class CodeFlashBenchmarkPlugin:

@@ -251,8 +251,12 @@ def wrapped_func(*args, **kwargs):  # noqa: ANN002, ANN003, ANN202

     def _run_benchmark(self, func, *args, **kwargs):  # noqa: ANN001, ANN002, ANN003, ANN202
         """Actual benchmark implementation."""
+        node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
+        if node_path is None:
+            raise RuntimeError("Unable to determine test file path from pytest node")
+
         benchmark_module_path = module_name_from_file_path(
-            Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
+            Path(str(node_path)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
         )

         benchmark_function_name = self.request.node.name

@@ -282,26 +286,3 @@ def _run_benchmark(self, func, *args, **kwargs):  # noqa: ANN001, ANN002, ANN003


 codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
-
-
-def pytest_configure(config: pytest.Config) -> None:
-    """Register the benchmark marker and disable conflicting plugins."""
-    config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
-
-    if config.getoption("--codeflash-trace") and IS_PYTEST_BENCHMARK_INSTALLED:
-        config.option.benchmark_disable = True
-        config.pluginmanager.set_blocked("pytest_benchmark")
-        config.pluginmanager.set_blocked("pytest-benchmark")
-
-
-def pytest_addoption(parser: pytest.Parser) -> None:
-    parser.addoption(
-        "--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
-    )
-
-
-@pytest.fixture
-def benchmark(request: pytest.FixtureRequest) -> object:
-    if not request.config.getoption("--codeflash-trace"):
-        return lambda func, *args, **kwargs: func(*args, **kwargs)
-    return codeflash_benchmark_plugin.Benchmark(request)
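The `getattr` chain in `_run_benchmark` is a compatibility shim: pytest 7 introduced the pathlib-based `node.path` and deprecated the legacy py.path `node.fspath`, so preferring `path` with an `fspath` fallback works on both sides of that change. The same pattern in isolation (a sketch; the helper name is hypothetical):

    from pathlib import Path


    def node_file_path(node) -> Path:
        # Prefer pathlib-based `path` (pytest >= 7), fall back to the legacy
        # py.path-based `fspath`, and fail loudly if neither is available.
        raw = getattr(node, "path", None) or getattr(node, "fspath", None)
        if raw is None:
            raise RuntimeError("Unable to determine test file path from pytest node")
        return Path(str(raw))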

codeflash/benchmarking/pytest_new_process_trace_benchmarks.py

Lines changed: 9 additions & 3 deletions

@@ -7,11 +7,14 @@
 benchmarks_root = sys.argv[1]
 tests_root = sys.argv[2]
 trace_file = sys.argv[3]
-# current working directory
 project_root = Path.cwd()
+
 if __name__ == "__main__":
     import pytest

+    orig_recursion_limit = sys.getrecursionlimit()
+    sys.setrecursionlimit(orig_recursion_limit * 2)
+
     try:
         codeflash_benchmark_plugin.setup(trace_file, project_root)
         codeflash_trace.setup(trace_file)

@@ -32,9 +35,12 @@
                 "addopts=",
             ],
             plugins=[codeflash_benchmark_plugin],
-        )  # Errors will be printed to stdout, not stderr
-
+        )
     except Exception as e:
         print(f"Failed to collect tests: {e!s}", file=sys.stderr)
         exitcode = -1
+    finally:
+        # Restore the original recursion limit
+        sys.setrecursionlimit(orig_recursion_limit)
+
     sys.exit(exitcode)
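Doubling the recursion limit before running benchmark collection and restoring it in `finally` guards against deeply recursive user code exhausting the default limit while tracing adds extra frames. If the same pattern were needed elsewhere, it could be packaged as a context manager (a sketch, not part of the commit):

    import sys
    from contextlib import contextmanager


    @contextmanager
    def raised_recursion_limit(factor: int = 2):
        # Save, raise, and always restore the recursion limit -- the same
        # try/finally dance the script above performs inline.
        original = sys.getrecursionlimit()
        sys.setrecursionlimit(original * factor)
        try:
            yield
        finally:
            sys.setrecursionlimit(original)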

codeflash/benchmarking/replay_test.py

Lines changed: 15 additions & 1 deletion

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import re
 import sqlite3
 import textwrap
 from pathlib import Path

@@ -14,6 +15,8 @@
 if TYPE_CHECKING:
     from collections.abc import Generator

+benchmark_context_cleaner = re.compile(r"[^a-zA-Z0-9_]+")
+

 def get_next_arg_and_return(
     trace_file: str,

@@ -46,6 +49,16 @@ def get_function_alias(module: str, function_name: str) -> str:
     return "_".join(module.split(".")) + "_" + function_name


+def get_unique_test_name(module: str, function_name: str, benchmark_name: str, class_name: str | None = None) -> str:
+    clean_benchmark = benchmark_context_cleaner.sub("_", benchmark_name).strip("_")
+
+    base_alias = get_function_alias(module, function_name)
+    if class_name:
+        class_alias = get_function_alias(module, class_name)
+        return f"{class_alias}_{function_name}_{clean_benchmark}"
+    return f"{base_alias}_{clean_benchmark}"
+
+
 def create_trace_replay_test_code(
     trace_file: str,
     functions_data: list[dict[str, Any]],

@@ -209,7 +222,8 @@ def create_trace_replay_test_code(
         formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")

         test_template += "    " if test_framework == "unittest" else ""
-        test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
+        unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
+        test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"

     return imports + "\n" + metadata + "\n" + test_template
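`get_unique_test_name` exists so the same function replayed under different benchmarks gets distinct replay-test names: the regex collapses every run of non-identifier characters in the benchmark name (brackets, dashes, parametrization ids) into a single underscore. A worked example with hypothetical inputs:

    >>> get_unique_test_name("pkg.mod", "compute", "test_bench[case-1]")
    'pkg_mod_compute_test_bench_case_1'
    >>> get_unique_test_name("pkg.mod", "run", "test_bench[case-1]", class_name="Worker")
    'pkg_mod_Worker_run_test_bench_case_1'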