Skip to content

Commit b4474f3

Browse files
add code repairs to the queue
1 parent d66d2ce commit b4474f3

File tree

2 files changed

+53
-70
lines changed

2 files changed

+53
-70
lines changed

codeflash/api/aiservice.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,12 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
324324
fixed_optimization = response.json()
325325
console.rule()
326326

327-
if not self._get_valid_candidates([fixed_optimization]):
327+
valid_candidates = self._get_valid_candidates([fixed_optimization])
328+
if not valid_candidates:
328329
logger.error("Code repair failed to generate a valid candidate.")
329330
return None
330331

331-
return OptimizedCandidate(
332-
source_code=fixed_optimization["source_code"],
333-
explanation=fixed_optimization["explanation"],
334-
optimization_id=fixed_optimization["optimization_id"],
335-
)
332+
return valid_candidates[0]
336333

337334
try:
338335
error = response.json()["error"]

codeflash/optimization/function_optimizer.py

Lines changed: 50 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def __init__(
126126
self,
127127
initial_candidates: list,
128128
future_line_profile_results: concurrent.futures.Future,
129-
future_all_refinements: list,
129+
future_all_refinements: list[concurrent.futures.Future],
130+
future_all_code_repair: list[concurrent.futures.Future],
130131
) -> None:
131132
self.candidate_queue = queue.Queue()
132133
self.line_profiler_done = False
@@ -139,6 +140,7 @@ def __init__(
139140

140141
self.future_line_profile_results = future_line_profile_results
141142
self.future_all_refinements = future_all_refinements
143+
self.future_all_code_repair = future_all_code_repair
142144

143145
def get_next_candidate(self) -> OptimizedCandidate | None:
144146
"""Get the next candidate from the queue, handling async results as needed."""
@@ -151,6 +153,8 @@ def _handle_empty_queue(self) -> OptimizedCandidate | None:
151153
"""Handle empty queue by checking for pending async results."""
152154
if not self.line_profiler_done:
153155
return self._process_line_profiler_results()
156+
if len(self.future_all_code_repair) > 0:
157+
return self._process_code_repair()
154158
if self.line_profiler_done and not self.refinement_done:
155159
return self._process_refinement_results()
156160
return None # All done
@@ -190,10 +194,30 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
190194
logger.info(
191195
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
192196
)
197+
self.future_all_refinements = []
193198
self.refinement_done = True
194199

195200
return self.get_next_candidate()
196201

202+
def _process_code_repair(self) -> OptimizedCandidate | None:
    """Move finished code-repair results into the candidate queue.

    Waits for every future currently held in ``self.future_all_code_repair``,
    enqueues each truthy result as a new candidate, empties the list, and
    then returns whatever ``get_next_candidate`` yields next.
    """
    logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates")
    concurrent.futures.wait(self.future_all_code_repair)
    candidates_added = 0
    for future_code_repair in self.future_all_code_repair:
        possible_code_repair = future_code_repair.result()
        # A falsy result means this repair produced no candidate to evaluate.
        if possible_code_repair:
            self.candidate_queue.put(possible_code_repair)
            self.candidate_len += 1
            candidates_added += 1

    if candidates_added > 0:
        logger.info(
            f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}"
        )

    # Empty the list in place instead of rebinding the attribute: this is the
    # list object that was passed to the constructor, and the caller keeps
    # appending new repair futures to it. Rebinding to a fresh list would
    # detach this processor from it, so repairs queued later would never be
    # picked up. Clearing unconditionally also guarantees the same futures
    # are not processed again on the next empty-queue check.
    self.future_all_code_repair.clear()

    return self.get_next_candidate()
220+
197221
def is_done(self) -> bool:
    """Check if processing is complete.

    Processing is complete only when line-profiler and refinement results
    have been consumed, no code-repair futures are still pending, and the
    candidate queue is empty.
    """
    # Pending repairs still have to be turned into candidates by the
    # empty-queue handler, so the processor is not finished while any remain.
    if self.future_all_code_repair:
        return False
    return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
@@ -250,6 +274,8 @@ def __init__(
250274
)
251275
self.optimization_review = ""
252276
self.ast_code_to_id = {}
277+
self.future_all_refinements: list[concurrent.futures.Future] = []
278+
self.future_all_code_repair: list[concurrent.futures.Future] = []
253279

254280
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
255281
should_run_experiment = self.experiment_id is not None
@@ -528,8 +554,10 @@ def determine_best_candidate(
528554
)
529555
console.rule()
530556

531-
future_all_refinements: list[concurrent.futures.Future] = []
532557
self.ast_code_to_id.clear()
558+
self.future_all_refinements.clear()
559+
self.future_all_code_repair.clear()
560+
533561
valid_optimizations = []
534562
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
535563

@@ -550,7 +578,9 @@ def determine_best_candidate(
550578
)
551579

552580
# Initialize candidate processor
553-
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
581+
processor = CandidateProcessor(
582+
candidates, future_line_profile_results, self.future_all_refinements, self.future_all_code_repair
583+
)
554584
candidate_index = 0
555585

556586
# Process candidates using queue-based approach
@@ -609,10 +639,8 @@ def determine_best_candidate(
609639
"shorter_source_code": candidate.source_code,
610640
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
611641
}
612-
self.reset_optimization_metrics_for_candidate(
613-
candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes
614-
)
615-
run_results, new_candidate = self.run_optimized_candidate(
642+
643+
run_results = self.run_optimized_candidate(
616644
optimization_candidate_index=candidate_index,
617645
baseline_results=original_code_baseline,
618646
original_helper_code=original_helper_code,
@@ -621,9 +649,6 @@ def determine_best_candidate(
621649
candidate=candidate,
622650
exp_type=exp_type,
623651
)
624-
if candidate.optimization_id != new_candidate.optimization_id:
625-
# override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair
626-
candidate = new_candidate
627652

628653
console.rule()
629654
if not is_successful(run_results):
@@ -715,7 +740,7 @@ def determine_best_candidate(
715740
valid_optimizations.append(best_optimization)
716741
# queue corresponding refined optimization for best optimization
717742
if not candidate.optimization_id.endswith("refi"):
718-
future_all_refinements.append(
743+
self.future_all_refinements.append(
719744
self.refine_optimizations(
720745
valid_optimizations=[best_optimization],
721746
original_code_baseline=original_code_baseline,
@@ -880,23 +905,24 @@ def refine_optimizations(
880905
]
881906
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
882907

883-
def code_repair_optimizations(
908+
def repair_optimization(
884909
self,
885910
original_source_code: str,
886911
modified_source_code: str,
887912
test_diffs: list[TestDiff],
888913
trace_id: str,
889914
optimization_id: str,
890915
ai_service_client: AiServiceClient,
891-
) -> OptimizedCandidate | None:
916+
executor: concurrent.futures.ThreadPoolExecutor,
917+
) -> concurrent.futures.Future[OptimizedCandidate | None]:
892918
request = AIServiceCodeRepairRequest(
893919
optimization_id=optimization_id,
894920
original_source_code=original_source_code,
895921
modified_source_code=modified_source_code,
896922
test_diffs=test_diffs,
897923
trace_id=trace_id,
898924
)
899-
return ai_service_client.optimize_python_code_repair(request=request)
925+
return executor.submit(ai_service_client.optimize_python_code_repair, request=request)
900926

901927
def log_successful_optimization(
902928
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
@@ -1816,7 +1842,7 @@ def get_results_not_matched_error(self) -> Failure:
18161842
console.rule()
18171843
return Failure("Test results did not match the test results of the original code.")
18181844

1819-
def run_optimized_candidate( # noqa: PLR0911
1845+
def run_optimized_candidate(
18201846
self,
18211847
*,
18221848
optimization_candidate_index: int,
@@ -1826,7 +1852,7 @@ def run_optimized_candidate( # noqa: PLR0911
18261852
code_context: CodeOptimizationContext,
18271853
candidate: OptimizedCandidate,
18281854
exp_type: str,
1829-
) -> tuple[Result[OptimizedCandidateResult, str], OptimizedCandidate]:
1855+
) -> Result[OptimizedCandidateResult, str]:
18301856
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
18311857

18321858
with progress_bar("Testing optimization candidate"):
@@ -1884,16 +1910,16 @@ def run_optimized_candidate( # noqa: PLR0911
18841910
result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
18851911
if result_unmatched_perc > 0.5:
18861912
# if more than 50% of the behavior test results do not match, we don't attempt a repair
1887-
return self.get_results_not_matched_error(), candidate
1913+
return self.get_results_not_matched_error()
18881914

18891915
if candidate.optimization_id.endswith("cdrp"):
18901916
# this candidate's id ends with "cdrp", so don't queue another repair for it; prevents looping for now
1891-
return self.get_results_not_matched_error(), candidate
1917+
return self.get_results_not_matched_error()
18921918

18931919
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
1894-
1895-
with progress_bar("Some of the test results are not matching, let me see if I can fix this"):
1896-
new_candidate = self.code_repair_optimizations(
1920+
logger.info("Adding this to the repair queue")
1921+
self.future_all_code_repair.append(
1922+
self.repair_optimization(
18971923
original_source_code=code_context.read_writable_code.markdown,
18981924
modified_source_code=candidate.source_code.markdown,
18991925
test_diffs=diffs,
@@ -1902,51 +1928,11 @@ def run_optimized_candidate( # noqa: PLR0911
19021928
else self.function_trace_id,
19031929
ai_service_client=ai_service_client,
19041930
optimization_id=candidate.optimization_id,
1931+
executor=self.executor,
19051932
)
1906-
if not new_candidate:
1907-
return Failure("Code repair failed to generate a valid candidate."), candidate
1908-
1909-
code_print(
1910-
new_candidate.source_code.flat,
1911-
file_name=f"candidate_{optimization_candidate_index}.py",
1912-
function_name=self.function_to_optimize.function_name,
19131933
)
1914-
normalized_code = normalize_code(new_candidate.source_code.flat.strip())
1915-
self.ast_code_to_id[normalized_code] = {
1916-
"optimization_id": new_candidate.optimization_id,
1917-
"shorter_source_code": new_candidate.source_code,
1918-
"diff_len": diff_length(new_candidate.source_code.flat, code_context.read_writable_code.flat),
1919-
}
19201934

1921-
try:
1922-
# revert first to original code then replace with new repaired code, so we don't get any weird behavior
1923-
self.write_code_and_helpers(
1924-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1925-
)
1926-
did_update = self.replace_function_and_helpers_with_optimized_code(
1927-
code_context=code_context,
1928-
optimized_code=new_candidate.source_code,
1929-
original_helper_code=original_helper_code,
1930-
)
1931-
if did_update:
1932-
return self.run_optimized_candidate(
1933-
optimization_candidate_index=optimization_candidate_index,
1934-
baseline_results=baseline_results,
1935-
original_helper_code=original_helper_code,
1936-
file_path_to_helper_classes=file_path_to_helper_classes,
1937-
code_context=code_context,
1938-
candidate=new_candidate,
1939-
exp_type=exp_type,
1940-
)
1941-
msg = "No functions were replaced in the optimized code. Skipping optimization candidate."
1942-
logger.warning(f"force_lsp|{msg}")
1943-
return Failure(msg), candidate
1944-
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
1945-
logger.error(e)
1946-
self.write_code_and_helpers(
1947-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1948-
)
1949-
return Failure("Code repair failed to generate a valid candidate."), candidate
1935+
return self.get_results_not_matched_error()
19501936

19511937
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
19521938

@@ -2038,7 +2024,7 @@ def run_optimized_candidate( # noqa: PLR0911
20382024
total_candidate_timing=total_candidate_timing,
20392025
async_throughput=candidate_async_throughput,
20402026
)
2041-
), candidate
2027+
)
20422028

20432029
def run_and_parse_tests(
20442030
self,

0 commit comments

Comments
 (0)