@@ -370,6 +370,7 @@ def determine_best_candidate(
370370 )
371371 console .rule ()
372372 candidates = deque (candidates )
373+ refinement_done = False
373374 # Start a new thread for AI service request, start loop in main thread
374375 # check if aiservice request is complete, when it is complete, append result to the candidates list
375376 with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
@@ -535,7 +536,7 @@ def determine_best_candidate(
535536 )
536537 future_line_profile_results = None
537538
538- if len (candidates ) == 0 and len (self .valid_optimizations ) > 0 :
539+ if len (candidates ) == 0 and len (self .valid_optimizations ) > 0 and not refinement_done :
539540 # TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
540541 # are found. This way we can hide the time waiting for the LLM results.
541542 refinement_diffs = self .refine_optimizations (
@@ -551,15 +552,20 @@ def determine_best_candidate(
551552 ai_service_client = ai_service_client ,
552553 executor = executor ,
553554 )
554-
555- print ("hi" )
555+ more_opt_candidates = [OptimizedCandidate (source_code = refinement_diffs [i ], explanation = self .valid_optimizations [i ].candidate .explanation , optimization_id = self .valid_optimizations [i ].candidate .optimization_id ) for i in range (len (refinement_diffs ))]
556+ # we no longer need to apply diffs since we are generating the entire code again
557+ candidates .extend (more_opt_candidates )
558+ print ("added candidates from refinement" )
559+ original_len += len (more_opt_candidates )
560+ refinement_done = True
556561 except KeyboardInterrupt as e :
557562 self .write_code_and_helpers (
558563 self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
559564 )
560565 logger .exception (f"Optimization interrupted: { e } " )
561566 raise
562567
568+ #need to figure out best candidate here before we return best_optimization
563569 ai_service_client .log_results (
564570 function_trace_id = self .function_trace_id [:- 4 ] + exp_type if self .experiment_id else self .function_trace_id ,
565571 speedup_ratio = speedup_ratios ,
0 commit comments