Skip to content

Commit 32264a3

Browse files
authored
Merge branch 'main' into feat-staging
2 parents 7170b5c + d5ec766 commit 32264a3

28 files changed

+709
-151
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
fail-fast: false
1313
matrix:
14-
python-version: ["3.9", "3.10", "3.11", "3.12"]
14+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
1515
continue-on-error: true
1616
runs-on: ubuntu-latest
1717
steps:
@@ -30,4 +30,4 @@ jobs:
3030
run: uv sync
3131

3232
- name: Unit tests
33-
run: uv run pytest tests/ --benchmark-skip -m "not ci_skip"
33+
run: uv run pytest tests/
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""
2+
3+
__version__ = "0.1.0"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import importlib.util
4+
5+
import pytest
6+
7+
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
8+
9+
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
10+
11+
12+
def pytest_configure(config: pytest.Config) -> None:
    """Register the ``benchmark`` marker and neutralize pytest-benchmark when tracing.

    When ``--codeflash-trace`` is active and pytest-benchmark is installed, the
    upstream plugin is disabled and blocked so its ``benchmark`` fixture cannot
    clash with the codeflash one.
    """
    config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")

    tracing_enabled = config.getoption("--codeflash-trace")
    if tracing_enabled and PYTEST_BENCHMARK_INSTALLED:
        config.option.benchmark_disable = True
        # Block both spellings pytest may use to identify the plugin.
        for blocked_name in ("pytest_benchmark", "pytest-benchmark"):
            config.pluginmanager.set_blocked(blocked_name)
20+
21+
22+
def pytest_addoption(parser: pytest.Parser) -> None:
    """Add the ``--codeflash-trace`` CLI flag (off by default) to pytest."""
    parser.addoption(
        "--codeflash-trace",
        action="store_true",
        default=False,
        help="Enable CodeFlash tracing for benchmarks",
    )
26+
27+
28+
@pytest.fixture
def benchmark(request: pytest.FixtureRequest) -> object:
    """Provide a ``benchmark`` callable whether or not pytest-benchmark is installed.

    Resolution order:
    1. ``--codeflash-trace`` enabled -> codeflash's tracing Benchmark object.
    2. pytest-benchmark installed with an active session -> its real fixture
       (honouring ``--benchmark-skip``).
    3. Otherwise -> a pass-through callable that simply invokes the function.
    """
    cfg = request.config

    if cfg.getoption("--codeflash-trace"):
        # Tracing mode: always use the codeflash implementation.
        return codeflash_benchmark_plugin.Benchmark(request)

    def _passthrough(func, *args, **kwargs):  # noqa: ANN001, ANN002, ANN003, ANN202
        return func(*args, **kwargs)

    if not PYTEST_BENCHMARK_INSTALLED:
        return _passthrough

    from pytest_benchmark.fixture import BenchmarkFixture as BSF  # noqa: N814

    session = getattr(cfg, "_benchmarksession", None)
    if session and session.skip:
        pytest.skip("Benchmarks are skipped (--benchmark-skip was used).")

    if not session:
        # pytest-benchmark is importable but has no session; behave as a no-op wrapper.
        return _passthrough

    marker = request.node.get_closest_marker("benchmark")
    marker_options = dict(marker.kwargs) if marker else {}
    return BSF(
        request.node,
        add_stats=session.benchmarks.append,
        logger=session.logger,
        warner=request.node.warn,
        disabled=session.disabled,
        **dict(session.options, **marker_options),
    )

codeflash-benchmark/pyproject.toml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
[project]
2+
name = "codeflash-benchmark"
3+
version = "0.1.0"
4+
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
5+
authors = [{ name = "CodeFlash Inc.", email = "[email protected]" }]
6+
requires-python = ">=3.9"
7+
readme = "README.md"
8+
license = {text = "BSL-1.1"}
9+
keywords = [
10+
"codeflash",
11+
"benchmark",
12+
"pytest",
13+
"performance",
14+
"testing",
15+
]
16+
dependencies = [
17+
"pytest>=7.0.0,!=8.3.4",
18+
]
19+
20+
[project.urls]
21+
Homepage = "https://codeflash.ai"
22+
Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
23+
24+
[project.entry-points.pytest11]
25+
codeflash-benchmark = "codeflash_benchmark.plugin"
26+
27+
[build-system]
28+
requires = ["setuptools>=45", "wheel", "setuptools_scm"]
29+
build-backend = "setuptools.build_meta"
30+
31+
[tool.setuptools]
32+
packages = ["codeflash_benchmark"]

codeflash/api/aiservice.py

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from codeflash.cli_cmds.console import console, logger
1313
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
15-
from codeflash.models.models import OptimizedCandidate
15+
from codeflash.models.ExperimentMetadata import ExperimentMetadata
16+
from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
1617
from codeflash.telemetry.posthog_cf import ph
1718
from codeflash.version import __version__ as codeflash_version
1819

@@ -21,6 +22,7 @@
2122

2223
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2324
from codeflash.models.ExperimentMetadata import ExperimentMetadata
25+
from codeflash.models.models import AIServiceRefinerRequest
2426

2527

2628
class AiServiceClient:
@@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
3638
return "https://app.codeflash.ai"
3739

3840
def make_ai_service_request(
39-
self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
41+
self,
42+
endpoint: str,
43+
method: str = "POST",
44+
payload: dict[str, Any] | list[dict[str, Any]] | None = None,
45+
timeout: float | None = None,
4046
) -> requests.Response:
4147
"""Make an API request to the given endpoint on the AI service.
4248
@@ -98,11 +104,7 @@ def optimize_python_code( # noqa: D417
98104
99105
"""
100106
start_time = time.perf_counter()
101-
try:
102-
git_repo_owner, git_repo_name = get_repo_owner_and_name()
103-
except Exception as e:
104-
logger.warning(f"Could not determine repo owner and name: {e}")
105-
git_repo_owner, git_repo_name = None, None
107+
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
106108

107109
payload = {
108110
"source_code": source_code,
@@ -219,13 +221,72 @@ def optimize_python_code_line_profiler( # noqa: D417
219221
console.rule()
220222
return []
221223

224+
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
    """Optimize the given python code for performance by making a request to the Django endpoint.

    Args:
        request: A list of optimization candidate details for refinement

    Returns:
    -------
    - List[OptimizationCandidate]: A list of Optimization Candidates.

    """
    # The payload keys mirror the AIServiceRefinerRequest attribute names one-to-one.
    field_names = (
        "optimization_id",
        "original_source_code",
        "read_only_dependency_code",
        "original_line_profiler_results",
        "original_code_runtime",
        "optimized_source_code",
        "optimized_explanation",
        "optimized_line_profiler_results",
        "optimized_code_runtime",
        "speedup",
        "trace_id",
    )
    payload = [{name: getattr(opt, name) for name in field_names} for opt in request]

    logger.info(f"Refining {len(request)} optimizations…")
    console.rule()
    try:
        response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating optimization refinements: {e}")
        ph("cli-optimize-error-caught", {"error": str(e)})
        return []

    if response.status_code != 200:
        try:
            error = response.json()["error"]
        except Exception:
            error = response.text
        logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
        console.rule()
        return []

    refined_optimizations = response.json()["refinements"]
    logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
    console.rule()
    # Refined candidates get a derived id: the original id with its last 4 chars swapped for "refi".
    return [
        OptimizedCandidate(
            source_code=opt["source_code"],
            explanation=opt["explanation"],
            optimization_id=opt["optimization_id"][:-4] + "refi",
        )
        for opt in refined_optimizations
    ]
280+
222281
def log_results( # noqa: D417
223282
self,
224283
function_trace_id: str,
225284
speedup_ratio: dict[str, float | None] | None,
226285
original_runtime: float | None,
227286
optimized_runtime: dict[str, float | None] | None,
228287
is_correct: dict[str, bool] | None,
288+
optimized_line_profiler_results: dict[str, str] | None,
289+
metadata: dict[str, Any] | None,
229290
) -> None:
230291
"""Log features to the database.
231292
@@ -236,6 +297,8 @@ def log_results( # noqa: D417
236297
- original_runtime (Optional[Dict[str, float]]): The original runtime.
237298
- optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
238299
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
300+
- optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
301+
- metadata: contains the best optimization id
239302
240303
"""
241304
payload = {
@@ -245,6 +308,8 @@ def log_results( # noqa: D417
245308
"optimized_runtime": optimized_runtime,
246309
"is_correct": is_correct,
247310
"codeflash_version": codeflash_version,
311+
"optimized_line_profiler_results": optimized_line_profiler_results,
312+
"metadata": metadata,
248313
}
249314
try:
250315
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
@@ -331,3 +396,12 @@ class LocalAiServiceClient(AiServiceClient):
331396
def get_aiservice_base_url(self) -> str:
332397
"""Get the base URL for the local AI service."""
333398
return "http://localhost:8000"
399+
400+
401+
def safe_get_repo_owner_and_name() -> tuple[str | None, str | None]:
    """Return ``(owner, repo)`` for the current git repository, or ``(None, None)``.

    Any exception raised by ``get_repo_owner_and_name`` (e.g. when not inside a
    git repository) is logged as a warning instead of propagating.
    """
    try:
        owner, repo = get_repo_owner_and_name()
    except Exception as e:  # best-effort: repo metadata is optional
        logger.warning(f"Could not determine repo owner and name: {e}")
        owner, repo = None, None
    return owner, repo

codeflash/api/cfapi.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,16 @@ def mark_optimization_success(trace_id: str, *, is_optimization_found: bool) ->
316316
"""
317317
payload = {"trace_id": trace_id, "is_optimization_found": is_optimization_found}
318318
return make_cfapi_request(endpoint="/mark-as-success", method="POST", payload=payload)
319+
320+
321+
def send_completion_email() -> Response:
    """Send an email notification when codeflash --all completes.

    If the repository owner/name cannot be determined, the error is reported to
    Sentry and a synthetic 500 response is returned instead of raising.
    """
    try:
        owner, repo = get_repo_owner_and_name()
    except Exception as e:
        sentry_sdk.capture_exception(e)
        fallback = requests.Response()
        fallback.status_code = 500
        return fallback
    return make_cfapi_request(
        endpoint="/send-completion-email",
        method="POST",
        payload={"owner": owner, "repo": repo},
    )

codeflash/benchmarking/plugin/plugin.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
import importlib.util
34
import os
45
import sqlite3
56
import sys
67
import time
78
from pathlib import Path
9+
from typing import TYPE_CHECKING
810

911
import pytest
1012

1113
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1214
from codeflash.code_utils.code_utils import module_name_from_file_path
13-
from codeflash.models.models import BenchmarkKey
15+
16+
if TYPE_CHECKING:
17+
from codeflash.models.models import BenchmarkKey
18+
19+
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
1420

1521

1622
class CodeFlashBenchmarkPlugin:
@@ -71,6 +77,8 @@ def close(self) -> None:
7177

7278
@staticmethod
7379
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
80+
from codeflash.models.models import BenchmarkKey
81+
7482
"""Process the trace file and extract timing data for all functions.
7583
7684
Args:
@@ -131,6 +139,8 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
131139

132140
@staticmethod
133141
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
142+
from codeflash.models.models import BenchmarkKey
143+
134144
"""Extract total benchmark timings from trace files.
135145
136146
Args:
@@ -199,23 +209,6 @@ def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, AR
199209
# Close the database connection
200210
self.close()
201211

202-
@staticmethod
203-
def pytest_addoption(parser: pytest.Parser) -> None:
204-
parser.addoption("--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing")
205-
206-
@staticmethod
207-
def pytest_plugin_registered(plugin, manager) -> None: # noqa: ANN001
208-
# Not necessary since run with -p no:benchmark, but just in case
209-
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
210-
manager.unregister(plugin)
211-
212-
@staticmethod
213-
def pytest_configure(config: pytest.Config) -> None:
214-
"""Register the benchmark marker."""
215-
config.addinivalue_line(
216-
"markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing"
217-
)
218-
219212
@staticmethod
220213
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
221214
# Skip tests that don't have the benchmark fixture
@@ -258,9 +251,14 @@ def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003, ANN202
258251

259252
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
260253
"""Actual benchmark implementation."""
254+
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
255+
if node_path is None:
256+
raise RuntimeError("Unable to determine test file path from pytest node")
257+
261258
benchmark_module_path = module_name_from_file_path(
262-
Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root)
259+
Path(str(node_path)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
263260
)
261+
264262
benchmark_function_name = self.request.node.name
265263
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
266264
# Set env vars
@@ -286,13 +284,5 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
286284

287285
return result
288286

289-
@staticmethod
290-
@pytest.fixture
291-
def benchmark(request: pytest.FixtureRequest) -> object:
292-
if not request.config.getoption("--codeflash-trace"):
293-
return None
294-
295-
return CodeFlashBenchmarkPlugin.Benchmark(request)
296-
297287

298288
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

0 commit comments

Comments
 (0)