@@ -375,10 +375,10 @@ def determine_best_candidate(
375375 )
376376 console .rule ()
377377 candidates = deque (candidates )
378- refinement_done = False
378+ future_all_refinements : list [ concurrent . futures . Future ] = []
379379 # Start a new thread for AI service request, start loop in main thread
380380 # check if aiservice request is complete, when it is complete, append result to the candidates list
381- with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
381+ with concurrent .futures .ThreadPoolExecutor (max_workers = 3 ) as executor :
382382 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
383383 future_line_profile_results = executor .submit (
384384 ai_service_client .optimize_python_code_line_profiler ,
@@ -515,6 +515,19 @@ def determine_best_candidate(
515515 winning_replay_benchmarking_test_results = candidate_result .benchmarking_test_results ,
516516 )
517517 self .valid_optimizations .append (best_optimization )
518+ # queue corresponding refined optimization for best optimization
519+ future_all_refinements .append (
520+ self .refine_optimizations (
521+ valid_optimizations = [best_optimization ],
522+ original_code_baseline = original_code_baseline ,
523+ code_context = code_context ,
524+ trace_id = self .function_trace_id [:- 4 ] + exp_type
525+ if self .experiment_id
526+ else self .function_trace_id ,
527+ ai_service_client = ai_service_client ,
528+ executor = executor ,
529+ )
530+ )
518531 else :
519532 tree .add (
520533 f"Summed runtime: { humanize_runtime (best_test_runtime )} "
@@ -543,26 +556,16 @@ def determine_best_candidate(
543556 f"Added results from line profiler to candidates, total candidates now: { original_len } "
544557 )
545558 future_line_profile_results = None
546-
547- if len (candidates ) == 0 and len (self .valid_optimizations ) > 0 and not refinement_done :
548- # TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
549- # are found. This way we can hide the time waiting for the LLM results.
550- trace_id = self .function_trace_id
551- if trace_id .endswith (("EXP0" , "EXP1" )):
552- trace_id = trace_id [:- 4 ] + exp_type
553- # refinement_response is a dataclass with optimization_id, code and explanation
554- refinement_response = self .refine_optimizations (
555- valid_optimizations = self .valid_optimizations ,
556- original_code_baseline = original_code_baseline ,
557- code_context = code_context ,
558- trace_id = trace_id ,
559- ai_service_client = ai_service_client ,
560- executor = executor ,
561- )
559+ # all original candidates and lp andidates processed
560+ if (not len (candidates )) and line_profiler_done :
561+ # waiting just in case not all calls are finished
562+ concurrent .futures .wait (future_all_refinements )
563+ refinement_response = [
564+ future_refinement .result () for future_refinement in future_all_refinements
565+ ]
562566 candidates .extend (refinement_response )
563567 print ("Added candidates from refinement" )
564568 original_len += len (refinement_response )
565- refinement_done = True
566569 except KeyboardInterrupt as e :
567570 self .write_code_and_helpers (
568571 self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
@@ -605,7 +608,7 @@ def refine_optimizations(
605608 trace_id : str ,
606609 ai_service_client : AiServiceClient ,
607610 executor : concurrent .futures .ThreadPoolExecutor ,
608- ) -> list [ OptimizedCandidate ] :
611+ ) -> concurrent . futures . Future :
609612 request = [
610613 AIServiceRefinerRequest (
611614 optimization_id = opt .candidate .optimization_id ,
@@ -621,10 +624,8 @@ def refine_optimizations(
621624 optimized_line_profiler_results = opt .line_profiler_test_results ["str_out" ],
622625 )
623626 for opt in valid_optimizations
624- ] # TODO: multiple workers for this?
625- future_refinement_results = executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
626- concurrent .futures .wait ([future_refinement_results ])
627- return future_refinement_results .result ()
627+ ]
628+ return executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
628629
629630 def log_successful_optimization (
630631 self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str
0 commit comments