Commit 8a28d0d

Merge branch 'feat/feedback-loop-for-unmatched-test-results' of github.com:codeflash-ai/codeflash into feat/feedback-loop-for-unmatched-test-results
2 parents 6ea2545 + 5a7c356 commit 8a28d0d

4 files changed: 147 additions, 27 deletions


codeflash/api/aiservice.py

Lines changed: 54 additions & 1 deletion

@@ -27,7 +27,7 @@
 
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.ExperimentMetadata import ExperimentMetadata
-from codeflash.models.models import AIServiceRefinerRequest
+from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
 from codeflash.result.explanation import Explanation
 
 
@@ -294,6 +294,59 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
         console.rule()
         return []
 
+    def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]:
+        """Optimize the given python code for performance by making a request to the Django endpoint.
+
+        Args:
+            request: A list of optimization candidate details for refinement
+
+        Returns:
+        -------
+            - List[OptimizationCandidate]: A list of Optimization Candidates.
+
+        """
+        payload = [
+            {
+                "optimization_id": opt.optimization_id,
+                "original_source_code": opt.original_source_code,
+                "modified_source_code": opt.modified_source_code,
+                "trace_id": opt.trace_id,
+            }
+            for opt in request
+        ]
+        # logger.debug(f"Repair {len(request)} optimizations…")
+        console.rule()
+        try:
+            response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
+        except requests.exceptions.RequestException as e:
+            logger.exception(f"Error generating optimization repair: {e}")
+            ph("cli-optimize-error-caught", {"error": str(e)})
+            return []
+
+        if response.status_code == 200:
+            refined_optimizations = response.json()["code_repairs"]
+            logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
+            console.rule()
+
+            refinements = self._get_valid_candidates(refined_optimizations)
+            return [
+                OptimizedCandidate(
+                    source_code=c.source_code,
+                    explanation=c.explanation,
+                    optimization_id=c.optimization_id[:-4] + "cdrp",
+                )
+                for c in refinements
+            ]
+
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        console.rule()
+        return []
+
     def get_new_explanation(  # noqa: D417
         self,
         source_code: str,
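
Repaired candidates are re-tagged by replacing the last four characters of the source candidate's optimization id with "cdrp", mirroring the "refi" suffix used for refinement candidates elsewhere in this change. A tiny illustration with a made-up id:

    # Illustration of the id re-tagging above; the id value is made up.
    candidate_id = "9f2c51b0refi"             # e.g. an id from an earlier refinement pass
    repaired_id = candidate_id[:-4] + "cdrp"
    assert repaired_id == "9f2c51b0cdrp"      # repaired candidates carry the "cdrp" suffix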

codeflash/models/models.py

Lines changed: 9 additions & 0 deletions

@@ -48,6 +48,15 @@ class AIServiceRefinerRequest:
     function_references: str | None = None
 
 
+@dataclass(frozen=True)
+class AIServiceCodeRepairRequest:
+    optimization_id: str
+    original_source_code: str
+    modified_source_code: str
+    test_details: str
+    trace_id: str
+
+
 # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
 # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
 # of the module is foo.eggs.
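
A minimal construction sketch for the new request model; the field values below are hypothetical, and the import path follows the aiservice.py import above. frozen=True makes instances immutable.

    from codeflash.models.models import AIServiceCodeRepairRequest

    # All values below are hypothetical placeholders.
    request = AIServiceCodeRepairRequest(
        optimization_id="9f2c51b0refi",
        original_source_code="def add(a, b):\n    return a + b\n",
        modified_source_code="def add(a, b):\n    return a - b\n",  # the broken optimization to repair
        test_details="Expected Value: 4\nActual Value: 0\nError String:None\n",
        trace_id="trace-1234",
    )
    # request.trace_id = "other"  # would raise dataclasses.FrozenInstanceError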

codeflash/optimization/function_optimizer.py

Lines changed: 83 additions & 24 deletions

@@ -13,6 +13,7 @@
 from typing import TYPE_CHECKING
 
 import libcst as cst
+import sentry_sdk
 from rich.console import Group
 from rich.panel import Panel
 from rich.syntax import Syntax
@@ -69,6 +70,7 @@
 from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
 from codeflash.models.ExperimentMetadata import ExperimentMetadata
 from codeflash.models.models import (
+    AIServiceCodeRepairRequest,
     BestOptimization,
     CodeOptimizationContext,
     GeneratedTests,
@@ -589,6 +591,28 @@ def determine_best_candidate(
                     optimized_runtimes[candidate.optimization_id] = None
                     is_correct[candidate.optimization_id] = False
                     speedup_ratios[candidate.optimization_id] = None
+                    fail_value = run_results.value
+                    if (
+                        fail_value != "Test results did not match the test results of the original code."
+                        and len(future_all_refinements) <= 3
+                        and not candidate.optimization_id.endswith("cdrp")
+                    ):
+                        # # queue corresponding code repair optimization for best optimization
+                        future_all_refinements.append(
+                            self.code_repair_optimizations(
+                                original_source_code=candidate,
+                                modified_source_code=code_context,
+                                original_code_baseline=original_code_baseline,
+                                test_details="test_details",
+                                code_context=code_context,
+                                trace_id=self.function_trace_id[:-4] + exp_type
+                                if self.experiment_id
+                                else self.function_trace_id,
+                                ai_service_client=ai_service_client,
+                                executor=self.executor,
+                                function_references=function_references,
+                            )
+                        )
                 else:
                     candidate_result: OptimizedCandidateResult = run_results.unwrap()
                     best_test_runtime = candidate_result.best_test_runtime
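
The queueing condition added above can be read as a standalone predicate. The restatement below is for illustration only (it is not codeflash code); the string literal and the "cdrp" suffix come from the diff, and the failure payload is the aggregated per-test diff text that run_optimized_candidate now returns for repairable mismatches.

    # Illustration only: restates the gate from the hunk above.
    def should_queue_code_repair(fail_value: str, queued_futures: int, optimization_id: str) -> bool:
        return (
            # the Failure carries detailed diff text, not the generic "can't fix it" message
            fail_value != "Test results did not match the test results of the original code."
            # keep the number of queued repair/refinement futures bounded
            and queued_futures <= 3
            # never re-repair a candidate that already came out of code repair
            and not optimization_id.endswith("cdrp")
        )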
@@ -672,21 +696,21 @@
                         async_throughput=candidate_result.async_throughput,
                     )
                     valid_optimizations.append(best_optimization)
-                    # queue corresponding refined optimization for best optimization
-                    if not candidate.optimization_id.endswith("refi"):
-                        future_all_refinements.append(
-                            self.refine_optimizations(
-                                valid_optimizations=[best_optimization],
-                                original_code_baseline=original_code_baseline,
-                                code_context=code_context,
-                                trace_id=self.function_trace_id[:-4] + exp_type
-                                if self.experiment_id
-                                else self.function_trace_id,
-                                ai_service_client=ai_service_client,
-                                executor=self.executor,
-                                function_references=function_references,
-                            )
-                        )
+                    # # queue corresponding refined optimization for best optimization
+                    # if not candidate.optimization_id.endswith("refi"):
+                    #     future_all_refinements.append(
+                    #         self.refine_optimizations(
+                    #             valid_optimizations=[best_optimization],
+                    #             original_code_baseline=original_code_baseline,
+                    #             code_context=code_context,
+                    #             trace_id=self.function_trace_id[:-4] + exp_type
+                    #             if self.experiment_id
+                    #             else self.function_trace_id,
+                    #             ai_service_client=ai_service_client,
+                    #             executor=self.executor,
+                    #             function_references=function_references,
+                    #         )
+                    #     )
                 else:
                     # For async functions, prioritize throughput metrics over runtime even for slow candidates
                     is_async = (
@@ -839,6 +863,26 @@ def refine_optimizations(
         ]
         return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
 
+    def code_repair_optimizations(
+        self,
+        original_source_code: str,
+        modified_source_code: str,
+        test_details: str,
+        trace_id: str,
+        ai_service_client: AiServiceClient,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ) -> concurrent.futures.Future:
+        request = [
+            AIServiceCodeRepairRequest(
+                optimization_id="",
+                original_source_code=original_source_code,
+                modified_source_code=modified_source_code,
+                test_details=test_details,
+                trace_id=trace_id,
+            )
+        ]
+        return executor.submit(ai_service_client.optimize_python_code_repair, request=request)
+
     def log_successful_optimization(
         self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
     ) -> None:
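
Like refine_optimizations above, the new helper submits the request to a shared ThreadPoolExecutor and returns the Future; the caller (determine_best_candidate) appends it to future_all_refinements and drains it later. Below is a generic, self-contained sketch of that submit-and-collect pattern, with a stand-in for the AI-service call (it is not codeflash code):

    import concurrent.futures

    # fake_repair stands in for ai_service_client.optimize_python_code_repair, which
    # returns a list of candidates (or [] when the request fails).
    def queue_repair(executor: concurrent.futures.ThreadPoolExecutor, payload: str) -> concurrent.futures.Future:
        def fake_repair(p: str) -> list:
            return [f"repaired::{p}"]
        return executor.submit(fake_repair, payload)

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        pending = [queue_repair(executor, name) for name in ("cand_a", "cand_b")]
        for done in concurrent.futures.as_completed(pending):
            for candidate in done.result():
                print(candidate)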
@@ -1813,6 +1857,7 @@ def run_optimized_candidate(
                 )
             )
             console.rule()
+            # print(type(code_context), type(candidate))
             match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
             if match:
                 logger.info("h3|Test results matched ✅")
@@ -1823,15 +1868,29 @@
                     # if the test unmatched percentage is greater than 50%, we can't fix it
                     return self.get_results_not_matched_error()
 
-                print(f"should try to fix it, diffs: {diffs}")
-                # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again
-                # self.run_optimized_candidate(
-                #     optimization_candidate_index=optimization_candidate_index,
-                #     baseline_results=baseline_results,
-                #     original_helper_code=original_helper_code,
-                #     file_path_to_helper_classes=file_path_to_helper_classes,
-                # )
-                return self.get_results_not_matched_error()
+                logger.info("running code repair...")
+                # not sure if all return types will be convertible to string
+                diff_per_test_fn = {}
+                for diff in diffs:
+                    try:
+                        diff_per_test_fn[diff.test_src_code] = (
+                            diff_per_test_fn.setdefault(diff.test_src_code, "")
+                            + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n"
+                        )
+
+                    except Exception as e:
+                        sentry_sdk.capture_exception(e)
+                        logger.exception(e)
+                        return self.get_results_not_matched_error()
+                try:
+                    test_issues = "\n".join(
+                        f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items()
+                    )
+                except Exception as e:
+                    sentry_sdk.capture_exception(e)
+                    logger.exception(e)
+                    return self.get_results_not_matched_error()
+                return Failure(test_issues)
 
             logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
 
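
When behavior tests mismatch but the unmatched percentage is not above the 50% cutoff, the candidate is no longer dropped immediately: the per-test differences are folded into a single string and returned as Failure(test_issues), which determine_best_candidate then feeds to code repair. Below is a standalone sketch of that aggregation, using a hypothetical stand-in for the diff objects (only the attributes the loop reads are mirrored):

    from __future__ import annotations

    from dataclasses import dataclass

    @dataclass
    class TestDiff:  # hypothetical stand-in; real diffs come from compare_test_results
        test_src_code: str
        original_value: object
        candidate_value: object
        pytest_error: str | None

    diffs = [
        TestDiff("def test_add():\n    assert add(2, 2) == 4", 4, 5, None),
        TestDiff("def test_add():\n    assert add(2, 2) == 4", "4\n", "5\n", "AssertionError"),
    ]

    # Same aggregation as the hunk above: mismatches are grouped per test function source.
    diff_per_test_fn: dict[str, str] = {}
    for diff in diffs:
        diff_per_test_fn[diff.test_src_code] = (
            diff_per_test_fn.setdefault(diff.test_src_code, "")
            + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n"
        )

    test_issues = "\n".join(f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items())
    print(test_issues)  # both mismatches appear under the same test function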

codeflash/verification/equivalence.py

Lines changed: 1 addition & 2 deletions

@@ -118,7 +118,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
                 test_diff.original_value = original_test_result.stdout
                 test_diff.candidate_value = cdd_test_result.stdout
                 test_diffs.append(test_diff)
-                break
 
             if original_test_result.test_type in {
                 TestType.EXISTING_UNIT_TEST,
@@ -130,7 +129,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
                 test_diff.original_value = original_test_result.did_pass
                 test_diff.candidate_value = cdd_test_result.did_pass
                 test_diffs.append(test_diff)
-                break
+
     sys.setrecursionlimit(original_recursion_limit)
    if did_all_timeout:
        return False, test_diffs
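
With the two break statements removed, compare_test_results no longer stops the comparison loop at the first recorded difference; every mismatching result is appended to test_diffs, so the code-repair request sees the full set of differences. A minimal, generic illustration of the behavioral change (not codeflash code):

    # Minimal illustration: without the break, all mismatches are collected.
    mismatches = ["stdout differs", "did_pass differs", "return value differs"]
    test_diffs = []
    for m in mismatches:
        test_diffs.append(m)
        # previously: break  # would have stopped after the first mismatch
    assert test_diffs == mismatches  # all three reach the code-repair prompt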
