Skip to content

Commit c1ae81e

Browse files
working version
1 parent 5703889 commit c1ae81e

File tree

4 files changed

+103 additions, -113 deletions

codeflash/api/aiservice.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -294,50 +294,39 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
294294
console.rule()
295295
return []
296296

297-
def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]:
297+
def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
298298
"""Optimize the given python code for performance by making a request to the Django endpoint.
299299
300300
Args:
301-
request: A list of optimization candidate details for refinement
301+
request: optimization candidate details for refinement
302302
303303
Returns:
304304
-------
305-
- List[OptimizationCandidate]: A list of Optimization Candidates.
305+
- OptimizationCandidate: new fixed candidate.
306306
307307
"""
308-
payload = [
309-
{
310-
"optimization_id": opt.optimization_id,
311-
"original_source_code": opt.original_source_code,
312-
"modified_source_code": opt.modified_source_code,
313-
"test_details": opt.test_details,
314-
"trace_id": opt.trace_id,
315-
}
316-
for opt in request
317-
]
318-
# logger.debug(f"Repair {len(request)} optimizations…")
319308
console.rule()
320309
try:
321-
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
310+
response = self.make_ai_service_request("/code_repair", payload=request, timeout=120)
322311
except requests.exceptions.RequestException as e:
323312
logger.exception(f"Error generating optimization repair: {e}")
324313
ph("cli-optimize-error-caught", {"error": str(e)})
325314
return []
326315

327316
if response.status_code == 200:
328-
refined_optimizations = response.json()["code_repairs"]
329-
# logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
317+
refined_optimization = response.json()
330318
console.rule()
331319

332-
refinements = self._get_valid_candidates(refined_optimizations)
333-
return [
334-
OptimizedCandidate(
335-
source_code=c.source_code,
336-
explanation=c.explanation,
337-
optimization_id=c.optimization_id[:-4] + "cdrp",
338-
)
339-
for c in refinements
340-
]
320+
refinements = self._get_valid_candidates([refined_optimization])
321+
if not refinements:
322+
logger.error("Code repair failed to generate a valid candidate.")
323+
return None
324+
325+
return OptimizedCandidate(
326+
source_code=refinements[0].source_code,
327+
explanation=refinements[0].explanation,
328+
optimization_id=refinements[0].optimization_id[:-4] + "cdrp",
329+
)
341330

342331
try:
343332
error = response.json()["error"]
@@ -346,7 +335,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest])
346335
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
347336
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
348337
console.rule()
349-
return []
338+
return None
350339

351340
def get_new_explanation( # noqa: D417
352341
self,

codeflash/models/models.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,33 @@ class AIServiceRefinerRequest:
4848
function_references: str | None = None
4949

5050

51+
class TestDiffScope(str, Enum):
52+
RETURN_VALUE = "return_value"
53+
STDOUT = "stdout"
54+
DID_PASS = "did_pass" # noqa: S105
55+
TIMED_OUT = "timed_out"
56+
57+
58+
@dataclass
59+
class TestDiff:
60+
scope: TestDiffScope
61+
original_pass: bool
62+
candidate_pass: bool
63+
64+
original_value: str | None = None
65+
candidate_value: str | None = None
66+
test_src_code: Optional[str] = None
67+
candidate_pytest_error: Optional[str] = None
68+
original_pytest_error: Optional[str] = None
69+
70+
5171
@dataclass(frozen=True)
5272
class AIServiceCodeRepairRequest:
5373
optimization_id: str
5474
original_source_code: str
5575
modified_source_code: str
56-
test_details: str
5776
trace_id: str
77+
test_diffs: list[TestDiff]
5878

5979

6080
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully

codeflash/optimization/function_optimizer.py

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
CoverageData,
117117
FunctionCalledInTest,
118118
FunctionSource,
119+
TestDiff,
119120
)
120121
from codeflash.verification.verification_utils import TestConfig
121122

@@ -685,32 +686,15 @@ def determine_best_candidate(
685686
baseline_results=original_code_baseline,
686687
original_helper_code=original_helper_code,
687688
file_path_to_helper_classes=file_path_to_helper_classes,
689+
code_context=code_context,
690+
candidate=candidate,
691+
exp_type=exp_type,
688692
)
689693
console.rule()
690694
if not is_successful(run_results):
691695
optimized_runtimes[candidate.optimization_id] = None
692696
is_correct[candidate.optimization_id] = False
693697
speedup_ratios[candidate.optimization_id] = None
694-
fail_value = run_results.value
695-
if (
696-
fail_value.strip() != "Test results did not match the test results of the original code."
697-
and len(future_all_refinements) <= 3
698-
and not candidate.optimization_id.endswith("cdrp")
699-
):
700-
# # queue corresponding code repair optimization for best optimization
701-
future_all_refinements.append(
702-
self.code_repair_optimizations(
703-
original_source_code=code_context.read_writable_code.markdown,
704-
modified_source_code=candidate.source_code.markdown,
705-
test_details=fail_value,
706-
trace_id=self.function_trace_id[:-4] + exp_type
707-
if self.experiment_id
708-
else self.function_trace_id,
709-
ai_service_client=ai_service_client,
710-
executor=self.executor,
711-
optimization_id=candidate.optimization_id,
712-
)
713-
)
714698
else:
715699
candidate_result: OptimizedCandidateResult = run_results.unwrap()
716700
best_test_runtime = candidate_result.best_test_runtime
@@ -978,22 +962,19 @@ def code_repair_optimizations(
978962
self,
979963
original_source_code: str,
980964
modified_source_code: str,
981-
test_details: str,
965+
test_diffs: list[TestDiff],
982966
trace_id: str,
983967
optimization_id: str,
984968
ai_service_client: AiServiceClient,
985-
executor: concurrent.futures.ThreadPoolExecutor,
986-
) -> concurrent.futures.Future:
987-
request = [
988-
AIServiceCodeRepairRequest(
989-
optimization_id=optimization_id,
990-
original_source_code=original_source_code,
991-
modified_source_code=modified_source_code,
992-
test_details=test_details,
993-
trace_id=trace_id,
994-
)
995-
]
996-
return executor.submit(ai_service_client.optimize_python_code_repair, request=request)
969+
) -> OptimizedCandidate | None:
970+
request = AIServiceCodeRepairRequest(
971+
optimization_id=optimization_id,
972+
original_source_code=original_source_code,
973+
modified_source_code=modified_source_code,
974+
test_diffs=test_diffs,
975+
trace_id=trace_id,
976+
)
977+
return ai_service_client.optimize_python_code_repair(request=request)
997978

998979
def log_successful_optimization(
999980
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
@@ -1920,6 +1901,9 @@ def run_optimized_candidate(
19201901
baseline_results: OriginalCodeBaseline,
19211902
original_helper_code: dict[Path, str],
19221903
file_path_to_helper_classes: dict[Path, set[str]],
1904+
code_context: CodeOptimizationContext,
1905+
candidate: OptimizedCandidate,
1906+
exp_type: str,
19231907
) -> Result[OptimizedCandidateResult, str]:
19241908
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
19251909

@@ -1980,29 +1964,50 @@ def run_optimized_candidate(
19801964
# if the test unmatched percentage is greater than 50%, we can't fix it
19811965
return self.get_results_not_matched_error()
19821966

1983-
logger.info("running code repair...")
1984-
# not sure if all return types will be convertible to string
1985-
diff_per_test_fn = {}
1986-
for diff in diffs:
1987-
try:
1988-
diff_per_test_fn[diff.test_src_code] = (
1989-
diff_per_test_fn.setdefault(diff.test_src_code, "")
1990-
+ f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.candidate_pytest_error}\n"
1991-
)
1967+
if candidate.optimization_id.endswith("cdrp"):
1968+
# prevent looping for now
1969+
return self.get_results_not_matched_error()
1970+
1971+
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
1972+
1973+
with progress_bar("The test results are not matching, let me see if I can fix this"):
1974+
new_candidate = self.code_repair_optimizations(
1975+
original_source_code=code_context.read_writable_code.markdown,
1976+
modified_source_code=candidate.source_code.markdown,
1977+
test_diffs=diffs,
1978+
trace_id=self.function_trace_id[:-4] + exp_type
1979+
if self.experiment_id
1980+
else self.function_trace_id,
1981+
ai_service_client=ai_service_client,
1982+
optimization_id=candidate.optimization_id,
1983+
)
1984+
if not new_candidate:
1985+
return Failure("Code repair failed to generate a valid candidate.")
1986+
1987+
code_print(new_candidate.source_code.flat)
19921988

1993-
except Exception as e:
1994-
sentry_sdk.capture_exception(e)
1995-
logger.exception(e)
1996-
return self.get_results_not_matched_error()
19971989
try:
1998-
test_issues = "\n".join(
1999-
f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items()
1990+
did_update = self.replace_function_and_helpers_with_optimized_code(
1991+
code_context=code_context,
1992+
optimized_code=new_candidate.source_code,
1993+
original_helper_code=original_helper_code,
20001994
)
2001-
except Exception as e:
2002-
sentry_sdk.capture_exception(e)
2003-
logger.exception(e)
2004-
return self.get_results_not_matched_error()
2005-
return Failure(test_issues)
1995+
if did_update:
1996+
return self.run_optimized_candidate(
1997+
optimization_candidate_index=optimization_candidate_index,
1998+
baseline_results=baseline_results,
1999+
original_helper_code=original_helper_code,
2000+
file_path_to_helper_classes=file_path_to_helper_classes,
2001+
code_context=code_context,
2002+
candidate=new_candidate,
2003+
exp_type=exp_type,
2004+
)
2005+
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
2006+
logger.error(e)
2007+
self.write_code_and_helpers(
2008+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
2009+
)
2010+
return Failure("Code repair failed to generate a valid candidate.")
20062011

20072012
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
20082013

codeflash/verification/equivalence.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from __future__ import annotations
22

33
import sys
4-
from dataclasses import dataclass
5-
from enum import Enum
6-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING
75

86
from codeflash.cli_cmds.console import logger
9-
from codeflash.models.models import TestResults, TestType, VerificationType
7+
from codeflash.models.models import TestDiff, TestDiffScope, TestResults, TestType, VerificationType
108
from codeflash.verification.comparator import comparator
119

1210
if TYPE_CHECKING:
@@ -15,26 +13,6 @@
1513
INCREASED_RECURSION_LIMIT = 5000
1614

1715

18-
class TestDiffScope(Enum):
19-
RETURN_VALUE = "return_value"
20-
STDOUT = "stdout"
21-
DID_PASS = "did_pass" # noqa: S105
22-
TIMED_OUT = "timed_out"
23-
24-
25-
@dataclass
26-
class TestDiff:
27-
scope: TestDiffScope
28-
original_value: any
29-
candidate_value: any
30-
original_pass: bool
31-
candidate_pass: bool
32-
33-
test_src_code: Optional[str] = None
34-
candidate_pytest_error: Optional[str] = None
35-
original_pytest_error: Optional[str] = None
36-
37-
3816
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
3917
# This is meant to be only called with test results for the first loop index
4018
if len(original_results) == 0 or len(candidate_results) == 0:
@@ -87,8 +65,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
8765
test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
8866
test_diff = TestDiff(
8967
scope=TestDiffScope.RETURN_VALUE,
90-
original_value=original_test_result.return_value,
91-
candidate_value=cdd_test_result.return_value,
68+
original_value=f"{original_test_result.return_value!r}",
69+
candidate_value=f"{cdd_test_result.return_value!r}",
9270
test_src_code=test_src_code,
9371
candidate_pytest_error=cdd_pytest_error,
9472
original_pass=original_test_result.did_pass,
@@ -97,8 +75,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
9775
)
9876
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
9977
test_diff.scope = TestDiffScope.RETURN_VALUE
100-
test_diff.original_value = original_test_result.return_value
101-
test_diff.candidate_value = cdd_test_result.return_value
10278
test_diffs.append(test_diff)
10379

10480
try:
@@ -117,8 +93,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
11793
original_test_result.stdout, cdd_test_result.stdout
11894
):
11995
test_diff.scope = TestDiffScope.STDOUT
120-
test_diff.original_value = original_test_result.stdout
121-
test_diff.candidate_value = cdd_test_result.stdout
96+
test_diff.original_value = str(original_test_result.stdout)
97+
test_diff.candidate_value = str(cdd_test_result.stdout)
12298
test_diffs.append(test_diff)
12399

124100
if original_test_result.test_type in {
@@ -128,8 +104,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
128104
TestType.REPLAY_TEST,
129105
} and (cdd_test_result.did_pass != original_test_result.did_pass):
130106
test_diff.scope = TestDiffScope.DID_PASS
131-
test_diff.original_value = original_test_result.did_pass
132-
test_diff.candidate_value = cdd_test_result.did_pass
107+
test_diff.original_value = str(original_test_result.did_pass)
108+
test_diff.candidate_value = str(cdd_test_result.did_pass)
133109
test_diffs.append(test_diff)
134110

135111
sys.setrecursionlimit(original_recursion_limit)

0 commit comments

Comments
 (0)