@@ -126,7 +126,8 @@ def __init__(
126126 self ,
127127 initial_candidates : list ,
128128 future_line_profile_results : concurrent .futures .Future ,
129- future_all_refinements : list ,
129+ future_all_refinements : list [concurrent .futures .Future ],
130+ future_all_code_repair : list [concurrent .futures .Future ],
130131 ) -> None :
131132 self .candidate_queue = queue .Queue ()
132133 self .line_profiler_done = False
@@ -139,6 +140,7 @@ def __init__(
139140
140141 self .future_line_profile_results = future_line_profile_results
141142 self .future_all_refinements = future_all_refinements
143+ self .future_all_code_repair = future_all_code_repair
142144
143145 def get_next_candidate (self ) -> OptimizedCandidate | None :
144146 """Get the next candidate from the queue, handling async results as needed."""
@@ -151,6 +153,8 @@ def _handle_empty_queue(self) -> OptimizedCandidate | None:
151153 """Handle empty queue by checking for pending async results."""
152154 if not self .line_profiler_done :
153155 return self ._process_line_profiler_results ()
156+ if len (self .future_all_code_repair ) > 0 :
157+ return self ._process_code_repair ()
154158 if self .line_profiler_done and not self .refinement_done :
155159 return self ._process_refinement_results ()
156160 return None # All done
@@ -190,10 +194,30 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
190194 logger .info (
191195 f"Added { len (refinement_response )} candidates from refinement, total candidates now: { self .candidate_len } "
192196 )
197+ self .future_all_refinements = []
193198 self .refinement_done = True
194199
195200 return self .get_next_candidate ()
196201
202+ def _process_code_repair (self ) -> OptimizedCandidate | None :
203+ logger .info (f"loading|Repairing { len (self .future_all_code_repair )} candidates" )
204+ concurrent .futures .wait (self .future_all_code_repair )
205+ candidates_added = 0
206+ for future_code_repair in self .future_all_code_repair :
207+ possible_code_repair = future_code_repair .result ()
208+ if possible_code_repair :
209+ self .candidate_queue .put (possible_code_repair )
210+ self .candidate_len += 1
211+ candidates_added += 1
212+
213+ if candidates_added > 0 :
214+ logger .info (
215+ f"Added { candidates_added } candidates from code repair, total candidates now: { self .candidate_len } "
216+ )
217+ self .future_all_code_repair = []
218+
219+ return self .get_next_candidate ()
220+
197221 def is_done (self ) -> bool :
198222 """Check if processing is complete."""
199223 return self .line_profiler_done and self .refinement_done and self .candidate_queue .empty ()
@@ -250,6 +274,8 @@ def __init__(
250274 )
251275 self .optimization_review = ""
252276 self .ast_code_to_id = {}
277+ self .future_all_refinements : list [concurrent .futures .Future ] = []
278+ self .future_all_code_repair : list [concurrent .futures .Future ] = []
253279
254280 def can_be_optimized (self ) -> Result [tuple [bool , CodeOptimizationContext , dict [Path , str ]], str ]:
255281 should_run_experiment = self .experiment_id is not None
@@ -528,8 +554,10 @@ def determine_best_candidate(
528554 )
529555 console .rule ()
530556
531- future_all_refinements : list [concurrent .futures .Future ] = []
532557 self .ast_code_to_id .clear ()
558+ self .future_all_refinements .clear ()
559+ self .future_all_code_repair .clear ()
560+
533561 valid_optimizations = []
534562 optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
535563
@@ -550,7 +578,9 @@ def determine_best_candidate(
550578 )
551579
552580 # Initialize candidate processor
553- processor = CandidateProcessor (candidates , future_line_profile_results , future_all_refinements )
581+ processor = CandidateProcessor (
582+ candidates , future_line_profile_results , self .future_all_refinements , self .future_all_code_repair
583+ )
554584 candidate_index = 0
555585
556586 # Process candidates using queue-based approach
@@ -609,10 +639,8 @@ def determine_best_candidate(
609639 "shorter_source_code" : candidate .source_code ,
610640 "diff_len" : diff_length (candidate .source_code .flat , code_context .read_writable_code .flat ),
611641 }
612- self .reset_optimization_metrics_for_candidate (
613- candidate .optimization_id , speedup_ratios , is_correct , optimized_runtimes
614- )
615- run_results , new_candidate = self .run_optimized_candidate (
642+
643+ run_results = self .run_optimized_candidate (
616644 optimization_candidate_index = candidate_index ,
617645 baseline_results = original_code_baseline ,
618646 original_helper_code = original_helper_code ,
@@ -621,9 +649,6 @@ def determine_best_candidate(
621649 candidate = candidate ,
622650 exp_type = exp_type ,
623651 )
624- if candidate .optimization_id != new_candidate .optimization_id :
625- # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair
626- candidate = new_candidate
627652
628653 console .rule ()
629654 if not is_successful (run_results ):
@@ -715,7 +740,7 @@ def determine_best_candidate(
715740 valid_optimizations .append (best_optimization )
716741 # queue corresponding refined optimization for best optimization
717742 if not candidate .optimization_id .endswith ("refi" ):
718- future_all_refinements .append (
743+ self . future_all_refinements .append (
719744 self .refine_optimizations (
720745 valid_optimizations = [best_optimization ],
721746 original_code_baseline = original_code_baseline ,
@@ -880,23 +905,24 @@ def refine_optimizations(
880905 ]
881906 return executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
882907
883- def code_repair_optimizations (
908+ def repair_optimization (
884909 self ,
885910 original_source_code : str ,
886911 modified_source_code : str ,
887912 test_diffs : list [TestDiff ],
888913 trace_id : str ,
889914 optimization_id : str ,
890915 ai_service_client : AiServiceClient ,
891- ) -> OptimizedCandidate | None :
916+ executor : concurrent .futures .ThreadPoolExecutor ,
917+ ) -> concurrent .futures .Future [OptimizedCandidate | None ]:
892918 request = AIServiceCodeRepairRequest (
893919 optimization_id = optimization_id ,
894920 original_source_code = original_source_code ,
895921 modified_source_code = modified_source_code ,
896922 test_diffs = test_diffs ,
897923 trace_id = trace_id ,
898924 )
899- return ai_service_client .optimize_python_code_repair ( request = request )
925+ return executor . submit ( ai_service_client .optimize_python_code_repair , request = request )
900926
901927 def log_successful_optimization (
902928 self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str
@@ -1816,7 +1842,7 @@ def get_results_not_matched_error(self) -> Failure:
18161842 console .rule ()
18171843 return Failure ("Test results did not match the test results of the original code." )
18181844
1819- def run_optimized_candidate ( # noqa: PLR0911
1845+ def run_optimized_candidate (
18201846 self ,
18211847 * ,
18221848 optimization_candidate_index : int ,
@@ -1826,7 +1852,7 @@ def run_optimized_candidate( # noqa: PLR0911
18261852 code_context : CodeOptimizationContext ,
18271853 candidate : OptimizedCandidate ,
18281854 exp_type : str ,
1829- ) -> tuple [ Result [OptimizedCandidateResult , str ], OptimizedCandidate ]:
1855+ ) -> Result [OptimizedCandidateResult , str ]:
18301856 assert (test_framework := self .args .test_framework ) in {"pytest" , "unittest" } # noqa: RUF018
18311857
18321858 with progress_bar ("Testing optimization candidate" ):
@@ -1884,16 +1910,16 @@ def run_optimized_candidate( # noqa: PLR0911
18841910 result_unmatched_perc = len (diffs ) / len (candidate_behavior_results )
18851911 if result_unmatched_perc > 0.5 :
18861912 # if the test unmatched percentage is greater than 50%, we can't fix it
1887- return self .get_results_not_matched_error (), candidate
1913+ return self .get_results_not_matched_error ()
18881914
18891915 if candidate .optimization_id .endswith ("cdrp" ):
18901916 # prevent looping for now
1891- return self .get_results_not_matched_error (), candidate
1917+ return self .get_results_not_matched_error ()
18921918
18931919 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
1894-
1895- with progress_bar ( "Some of the test results are not matching, let me see if I can fix this" ):
1896- new_candidate = self .code_repair_optimizations (
1920+ logger . info ( "Adding this to the repair queue" )
1921+ self . future_all_code_repair . append (
1922+ self .repair_optimization (
18971923 original_source_code = code_context .read_writable_code .markdown ,
18981924 modified_source_code = candidate .source_code .markdown ,
18991925 test_diffs = diffs ,
@@ -1902,51 +1928,11 @@ def run_optimized_candidate( # noqa: PLR0911
19021928 else self .function_trace_id ,
19031929 ai_service_client = ai_service_client ,
19041930 optimization_id = candidate .optimization_id ,
1931+ executor = self .executor ,
19051932 )
1906- if not new_candidate :
1907- return Failure ("Code repair failed to generate a valid candidate." ), candidate
1908-
1909- code_print (
1910- new_candidate .source_code .flat ,
1911- file_name = f"candidate_{ optimization_candidate_index } .py" ,
1912- function_name = self .function_to_optimize .function_name ,
19131933 )
1914- normalized_code = normalize_code (new_candidate .source_code .flat .strip ())
1915- self .ast_code_to_id [normalized_code ] = {
1916- "optimization_id" : new_candidate .optimization_id ,
1917- "shorter_source_code" : new_candidate .source_code ,
1918- "diff_len" : diff_length (new_candidate .source_code .flat , code_context .read_writable_code .flat ),
1919- }
19201934
1921- try :
1922- # revert first to original code then replace with new repaired code, so we don't get any weird behavior
1923- self .write_code_and_helpers (
1924- self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1925- )
1926- did_update = self .replace_function_and_helpers_with_optimized_code (
1927- code_context = code_context ,
1928- optimized_code = new_candidate .source_code ,
1929- original_helper_code = original_helper_code ,
1930- )
1931- if did_update :
1932- return self .run_optimized_candidate (
1933- optimization_candidate_index = optimization_candidate_index ,
1934- baseline_results = baseline_results ,
1935- original_helper_code = original_helper_code ,
1936- file_path_to_helper_classes = file_path_to_helper_classes ,
1937- code_context = code_context ,
1938- candidate = new_candidate ,
1939- exp_type = exp_type ,
1940- )
1941- msg = "No functions were replaced in the optimized code. Skipping optimization candidate."
1942- logger .warning (f"force_lsp|{ msg } " )
1943- return Failure (msg ), candidate
1944- except (ValueError , SyntaxError , cst .ParserSyntaxError , AttributeError ) as e :
1945- logger .error (e )
1946- self .write_code_and_helpers (
1947- self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1948- )
1949- return Failure ("Code repair failed to generate a valid candidate." ), candidate
1935+ return self .get_results_not_matched_error ()
19501936
19511937 logger .info (f"loading|Running performance tests for candidate { optimization_candidate_index } ..." )
19521938
@@ -2038,7 +2024,7 @@ def run_optimized_candidate( # noqa: PLR0911
20382024 total_candidate_timing = total_candidate_timing ,
20392025 async_throughput = candidate_async_throughput ,
20402026 )
2041- ), candidate
2027+ )
20422028
20432029 def run_and_parse_tests (
20442030 self ,
0 commit comments