Skip to content

Commit b478f10

Browse files
authored
Merge pull request #555 from codeflash-ai/refinement
Refinement
2 parents 52083c4 + f4c6c7d commit b478f10

File tree

6 files changed

+323
-57
lines changed

6 files changed

+323
-57
lines changed

codeflash/api/aiservice.py

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from codeflash.cli_cmds.console import console, logger
1313
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
15-
from codeflash.models.models import OptimizedCandidate
15+
from codeflash.models.ExperimentMetadata import ExperimentMetadata
16+
from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
1617
from codeflash.telemetry.posthog_cf import ph
1718
from codeflash.version import __version__ as codeflash_version
1819

@@ -21,6 +22,7 @@
2122

2223
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2324
from codeflash.models.ExperimentMetadata import ExperimentMetadata
25+
from codeflash.models.models import AIServiceRefinerRequest
2426

2527

2628
class AiServiceClient:
@@ -36,7 +38,11 @@ def get_aiservice_base_url(self) -> str:
3638
return "https://app.codeflash.ai"
3739

3840
def make_ai_service_request(
39-
self, endpoint: str, method: str = "POST", payload: dict[str, Any] | None = None, timeout: float | None = None
41+
self,
42+
endpoint: str,
43+
method: str = "POST",
44+
payload: dict[str, Any] | list[dict[str, Any]] | None = None,
45+
timeout: float | None = None,
4046
) -> requests.Response:
4147
"""Make an API request to the given endpoint on the AI service.
4248
@@ -98,11 +104,7 @@ def optimize_python_code( # noqa: D417
98104
99105
"""
100106
start_time = time.perf_counter()
101-
try:
102-
git_repo_owner, git_repo_name = get_repo_owner_and_name()
103-
except Exception as e:
104-
logger.warning(f"Could not determine repo owner and name: {e}")
105-
git_repo_owner, git_repo_name = None, None
107+
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
106108

107109
payload = {
108110
"source_code": source_code,
@@ -219,13 +221,72 @@ def optimize_python_code_line_profiler( # noqa: D417
219221
console.rule()
220222
return []
221223

224+
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
    """Ask the AI service to refine a batch of existing optimization candidates.

    Every entry of ``request`` is flattened into a plain dict and the resulting
    list is sent to the ``/refinement`` endpoint in a single call.

    Args:
        request: The candidates to refine, one ``AIServiceRefinerRequest`` each.

    Returns:
        One ``OptimizedCandidate`` per refinement returned by the service. Each
        returned ``optimization_id`` has its last four characters replaced by
        ``"refi"``. An empty list is returned when the request raises a
        ``requests`` exception or the service answers with a non-200 status.

    Note:
        A 200 response whose body is not JSON, or lacks the ``"refinements"``
        key, is not handled here and raises.

    """
    # Attributes copied verbatim (same name, same order) into each payload entry.
    payload_fields = (
        "optimization_id",
        "original_source_code",
        "read_only_dependency_code",
        "original_line_profiler_results",
        "original_code_runtime",
        "optimized_source_code",
        "optimized_explanation",
        "optimized_line_profiler_results",
        "optimized_code_runtime",
        "speedup",
        "trace_id",
    )
    payload = [{field: getattr(item, field) for field in payload_fields} for item in request]
    logger.info(f"Refining {len(request)} optimizations…")
    console.rule()
    try:
        response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error generating optimization refinements: {e}")
        ph("cli-optimize-error-caught", {"error": str(e)})
        return []

    if response.status_code != 200:
        # Prefer the structured error message; fall back to the raw body.
        try:
            error = response.json()["error"]
        except Exception:
            error = response.text
        logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
        console.rule()
        return []

    refined_optimizations = response.json()["refinements"]
    logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
    console.rule()
    candidates = []
    for refined in refined_optimizations:
        # Swap the last four characters of the id for the "refi" marker.
        refined_id = refined["optimization_id"][:-4] + "refi"
        candidates.append(
            OptimizedCandidate(
                source_code=refined["source_code"],
                explanation=refined["explanation"],
                optimization_id=refined_id,
            )
        )
    return candidates
280+
222281
def log_results( # noqa: D417
223282
self,
224283
function_trace_id: str,
225284
speedup_ratio: dict[str, float | None] | None,
226285
original_runtime: float | None,
227286
optimized_runtime: dict[str, float | None] | None,
228287
is_correct: dict[str, bool] | None,
288+
optimized_line_profiler_results: dict[str, str] | None,
289+
metadata: dict[str, Any] | None,
229290
) -> None:
230291
"""Log features to the database.
231292
@@ -236,6 +297,8 @@ def log_results( # noqa: D417
236297
- original_runtime (Optional[Dict[str, float]]): The original runtime.
237298
- optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
238299
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
300+
- optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
301+
- metadata: contains the best optimization id
239302
240303
"""
241304
payload = {
@@ -245,6 +308,8 @@ def log_results( # noqa: D417
245308
"optimized_runtime": optimized_runtime,
246309
"is_correct": is_correct,
247310
"codeflash_version": codeflash_version,
311+
"optimized_line_profiler_results": optimized_line_profiler_results,
312+
"metadata": metadata,
248313
}
249314
try:
250315
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
@@ -331,3 +396,12 @@ class LocalAiServiceClient(AiServiceClient):
331396
def get_aiservice_base_url(self) -> str:
332397
"""Get the base URL for the local AI service."""
333398
return "http://localhost:8000"
399+
400+
401+
def safe_get_repo_owner_and_name() -> tuple[str | None, str | None]:
    """Return the repository owner and name, or ``(None, None)`` on any failure.

    Wraps ``get_repo_owner_and_name`` so that an exception raised while
    resolving the repository is logged as a warning instead of propagating.
    """
    try:
        owner, name = get_repo_owner_and_name()
    except Exception as e:
        logger.warning(f"Could not determine repo owner and name: {e}")
        return None, None
    return owner, name

codeflash/code_utils/code_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import difflib
45
import os
56
import re
67
import shutil
@@ -19,6 +20,50 @@
1920
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2021

2122

23+
def diff_length(a: str, b: str) -> int:
    """Compute the length (in characters) of the unified diff between two strings.

    Both inputs are split into lines with their line endings preserved, so a
    change that only affects a line ending (for example a missing final
    newline) still contributes to the diff.

    Args:
        a (str): Original string.
        b (str): Modified string.

    Returns:
        int: Total number of characters in the unified diff text; ``0`` when
        the two strings are identical.

    """
    a_lines = a.splitlines(keepends=True)
    b_lines = b.splitlines(keepends=True)

    # The input lines already carry their own terminators, so difflib's default
    # lineterm ("\n") is used for the control lines ("---", "+++", "@@") and the
    # pieces are summed as-is. Joining on "\n" with lineterm="" would count an
    # extra newline after every content line and inflate the result.
    return sum(len(line) for line in difflib.unified_diff(a_lines, b_lines))
45+
46+
47+
def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
    """Map every index of ``int_array`` to the ascending rank of its value.

    Rank ``0`` goes to the index holding the smallest value. Equal values are
    ranked in order of their original position because the sort is stable.

    Args:
        int_array: A list of integers.

    Returns:
        A dictionary whose keys are the original indices and whose values are
        the zero-based ranks of the corresponding elements.

    """
    # Order the positions by the value stored at each position.
    ordered_positions = list(range(len(int_array)))
    ordered_positions.sort(key=int_array.__getitem__)

    # A position's rank is simply where it landed in that ordering.
    ranks = {}
    for rank, position in enumerate(ordered_positions):
        ranks[position] = rank
    return ranks
65+
66+
2267
@contextmanager
2368
def custom_addopts() -> None:
2469
pyproject_file = find_pyproject_toml()

codeflash/models/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@
2727
from codeflash.code_utils.env_utils import is_end_to_end
2828
from codeflash.verification.comparator import comparator
2929

30+
31+
@dataclass(frozen=True)
class AIServiceRefinerRequest:
    """Immutable record describing one optimization candidate to be refined.

    ``AiServiceClient.optimize_python_code_refinement`` copies every field of
    this record, under the same name, into the payload entry it sends to the
    ``/refinement`` endpoint. All fields are declared as strings.
    """

    # Identifier of the candidate; the refinement client derives the refined
    # candidate's id from the id the service sends back.
    optimization_id: str
    # presumably the source of the function before optimization — confirm with caller
    original_source_code: str
    # presumably read-only context code the function depends on — confirm with caller
    read_only_dependency_code: str
    # NOTE(review): runtimes are typed as str, so they are presumably
    # pre-formatted for display rather than raw numbers — confirm with caller
    original_code_runtime: str
    # presumably the candidate's optimized source — confirm with caller
    optimized_source_code: str
    # presumably the explanation attached to the candidate — confirm with caller
    optimized_explanation: str
    optimized_code_runtime: str
    # NOTE(review): also a str; format (ratio, percentage, ...) is not visible here
    speedup: str
    # presumably the trace id of the optimization run — confirm with caller
    trace_id: str
    # Line-profiler output for the original and optimized code, as text.
    original_line_profiler_results: str
    optimized_line_profiler_results: str
44+
45+
3046
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
3147
# qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
3248
# of the module is foo.eggs.
@@ -76,11 +92,13 @@ def __hash__(self) -> int:
7692
class BestOptimization(BaseModel):
    """Bundle a selected optimization candidate with its context and test results.

    NOTE(review): field meanings below are inferred from names and types only;
    confirm against the code that constructs this model.
    """

    # The candidate this record is about.
    candidate: OptimizedCandidate
    helper_functions: list[FunctionSource]
    # presumably the code context the candidate was generated from — confirm
    code_context: CodeOptimizationContext
    # NOTE(review): integer runtime; unit is not visible here
    runtime: int
    # Optional per-benchmark gain; None when not available.
    replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
    winning_behavior_test_results: TestResults
    winning_benchmarking_test_results: TestResults
    winning_replay_benchmarking_test_results: Optional[TestResults] = None
    # NOTE(review): untyped dict; key/value schema is not visible here
    line_profiler_test_results: dict
84102

85103

86104
@dataclass(frozen=True)

0 commit comments

Comments
 (0)