@@ -547,7 +547,8 @@ def determine_best_candidate(
547547 trace_id = self .function_trace_id
548548 if trace_id .endswith (("EXP0" , "EXP1" )):
549549 trace_id = trace_id [:- 4 ] + exp_type
550- refinement_diffs = self .refine_optimizations (
550+ # refinement_dict is a dictionary with optimization_id as a key and the refined code as a value
551+ refinement_dict = self .refine_optimizations (
551552 valid_optimizations = self .valid_optimizations ,
552553 original_code_baseline = original_code_baseline ,
553554 code_context = code_context ,
@@ -561,15 +562,18 @@ def determine_best_candidate(
561562 executor = executor ,
562563 fto_name = self .function_to_optimize .qualified_name ,
563564 )
564- # filter out empty strings of code
565+
565566 more_opt_candidates = [
566567 OptimizedCandidate (
567- source_code = refinement_diffs [i ],
568- explanation = self .valid_optimizations [i ].candidate .explanation ,
569- optimization_id = self .valid_optimizations [i ].candidate .optimization_id ,
568+ source_code = code ,
569+ explanation = self .valid_optimizations [
570+ i
571+ ].candidate .explanation , # TODO: handle the new explanation after the refinement
572+ optimization_id = opt_id ,
570573 )
571- for i in range (len (refinement_diffs ))
572- if refinement_diffs [i ] != ""
574+ for i , (opt_id , code ) in enumerate (refinement_dict .items ())
575+ # filter out empty strings of code
576+ if code != ""
573577 ]
574578 # we no longer need to apply diffs since we are generating the entire code again
575579 candidates .extend (more_opt_candidates )
@@ -637,14 +641,16 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
637641 runtimes_ranking = create_rank_dictionary_compact (runtimes_list )
638642 overall_ranking = {key : diff_lens_ranking [key ] + runtimes_ranking [key ] for key in diff_lens_ranking .keys ()} # noqa: SIM118
639643 min_key = min (overall_ranking , key = overall_ranking .get )
644+ best_optimization = self .valid_optimizations [min_key ]
640645 ai_service_client .log_results (
641646 function_trace_id = self .function_trace_id [:- 4 ] + exp_type if self .experiment_id else self .function_trace_id ,
642647 speedup_ratio = speedup_ratios ,
643648 original_runtime = original_code_baseline .runtime ,
644649 optimized_runtime = optimized_runtimes ,
645650 is_correct = is_correct ,
651+ metadata = {"best_optimization_id" : best_optimization .candidate .optimization_id },
646652 )
647- return self . valid_optimizations [ min_key ]
653+ return best_optimization
648654
649655 def refine_optimizations (
650656 self ,
@@ -656,9 +662,10 @@ def refine_optimizations(
656662 ai_service_client : AiServiceClient ,
657663 executor : concurrent .futures .ThreadPoolExecutor ,
658664 fto_name : str ,
659- ) -> list [ str ]:
665+ ) -> dict [ str , str ]:
660666 request = [
661667 AIServiceRefinerRequest (
668+ optimization_id = opt .candidate .optimization_id ,
662669 original_source_code = code_context .read_writable_code ,
663670 read_only_dependency_code = code_context .read_only_context_code ,
664671 original_code_runtime = humanize_runtime (original_code_baseline .runtime ),
0 commit comments