3737 has_any_async_functions ,
3838 module_name_from_file_path ,
3939 restore_conftest ,
40+ diff_length ,
41+ create_rank_dictionary_compact ,
4042)
4143from codeflash .code_utils .config_consts import (
4244 INDIVIDUAL_TESTCASE_TIMEOUT ,
@@ -369,6 +371,7 @@ def determine_best_candidate(
369371 speedup_ratios : dict [str , float | None ] = {}
370372 optimized_runtimes : dict [str , float | None ] = {}
371373 is_correct = {}
374+ optimized_line_profiler_results : dict [str , str ] = {}
372375
373376 logger .info (
374377 f"Determining best optimization candidate (out of { len (candidates )} ) for "
@@ -464,7 +467,7 @@ def determine_best_candidate(
464467 candidate_result , original_code_baseline .runtime , best_runtime_until_now = None
465468 ) and quantity_of_tests_critic (candidate_result ):
466469 tree .add (
467- "This candidate is faster than the previous best candidate . 🚀"
470+ "This candidate is faster than the original code . 🚀"
468471 ) # TODO: Change this description
469472 tree .add (f"Original summed runtime: { humanize_runtime (original_code_baseline .runtime )} " )
470473 tree .add (
@@ -479,6 +482,7 @@ def determine_best_candidate(
479482 original_helper_code = original_helper_code ,
480483 candidate_index = candidate_index ,
481484 )
485+ optimized_line_profiler_results [candidate .optimization_id ]= line_profile_test_results ['str_out' ]
482486 replay_perf_gain = {}
483487 if self .args .benchmark :
484488 test_results_by_benchmark = (
@@ -547,8 +551,8 @@ def determine_best_candidate(
547551 trace_id = self .function_trace_id
548552 if trace_id .endswith (("EXP0" , "EXP1" )):
549553 trace_id = trace_id [:- 4 ] + exp_type
550- # refinement_dict is a dictionary with optimization_id as a key and the refined code as a value
551- refinement_dict = self .refine_optimizations (
554+ # refinement_response is a dataclass with optimization_id, code and explanation
555+ refinement_response = self .refine_optimizations (
552556 valid_optimizations = self .valid_optimizations ,
553557 original_code_baseline = original_code_baseline ,
554558 code_context = code_context ,
@@ -562,23 +566,9 @@ def determine_best_candidate(
562566 executor = executor ,
563567 fto_name = self .function_to_optimize .qualified_name ,
564568 )
565-
566- more_opt_candidates = [
567- OptimizedCandidate (
568- source_code = code ,
569- explanation = self .valid_optimizations [
570- i
571- ].candidate .explanation , # TODO: handle the new explanation after the refinement
572- optimization_id = opt_id ,
573- )
574- for i , (opt_id , code ) in enumerate (refinement_dict .items ())
575- # filter out empty strings of code
576- if code != ""
577- ]
578- # we no longer need to apply diffs since we are generating the entire code again
579- candidates .extend (more_opt_candidates )
580- print ("added candidates from refinement" )
581- original_len += len (more_opt_candidates )
569+ candidates .extend (refinement_response )
570+ print ("Added candidates from refinement" )
571+ original_len += len (refinement_response )
582572 refinement_done = True
583573 except KeyboardInterrupt as e :
584574 self .write_code_and_helpers (
@@ -587,58 +577,17 @@ def determine_best_candidate(
587577 logger .exception (f"Optimization interrupted: { e } " )
588578 raise
589579
590- def diff_length (a : str , b : str ) -> int :
591- """Compute the length (in characters) of the unified diff between two strings.
592-
593- Args:
594- a (str): Original string.
595- b (str): Modified string.
596-
597- Returns:
598- int: Total number of characters in the diff.
599-
600- """
601- # Split input strings into lines for line-by-line diff
602- a_lines = a .splitlines (keepends = True )
603- b_lines = b .splitlines (keepends = True )
604-
605- # Compute unified diff
606- diff_lines = list (difflib .unified_diff (a_lines , b_lines , lineterm = "" ))
607-
608- # Join all lines with newline to calculate total diff length
609- diff_text = "\n " .join (diff_lines )
610-
611- return len (diff_text )
612-
613- def create_rank_dictionary_compact (int_array : list [int ]) -> dict [int , int ]:
614- """Create a dictionary from a list of ints, mapping the original index to its rank.
615-
616- This version uses a more compact, "Pythonic" implementation.
617-
618- Args:
619- int_array: A list of integers.
620-
621- Returns:
622- A dictionary where keys are original indices and values are the
623- rank of the element in ascending order.
624-
625- """
626- # Sort the indices of the array based on their corresponding values
627- sorted_indices = sorted (range (len (int_array )), key = lambda i : int_array [i ])
628-
629- # Create a dictionary mapping the original index to its rank (its position in the sorted list)
630- return {original_index : rank for rank , original_index in enumerate (sorted_indices )}
631-
632580 if not len (self .valid_optimizations ):
633581 return None
634582 # need to figure out the best candidate here before we return best_optimization
635- diff_lens_list = []
583+ diff_lens_list = [] # character level diff
636584 runtimes_list = []
637585 for valid_opt in self .valid_optimizations :
638- diff_lens_list .append (diff_length (valid_opt .candidate .source_code , code_context .read_writable_code ))
586+ diff_lens_list .append (diff_length (valid_opt .candidate .source_code , code_context .read_writable_code )) #char level diff
639587 runtimes_list .append (valid_opt .runtime )
640588 diff_lens_ranking = create_rank_dictionary_compact (diff_lens_list )
641589 runtimes_ranking = create_rank_dictionary_compact (runtimes_list )
590+ # TODO: better way to resolve conflicts with same min ranking
642591 overall_ranking = {key : diff_lens_ranking [key ] + runtimes_ranking [key ] for key in diff_lens_ranking .keys ()} # noqa: SIM118
643592 min_key = min (overall_ranking , key = overall_ranking .get )
644593 best_optimization = self .valid_optimizations [min_key ]
@@ -649,6 +598,7 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
649598 optimized_runtime = optimized_runtimes ,
650599 is_correct = is_correct ,
651600 best_optimization_id = best_optimization .candidate .optimization_id ,
601+ optimized_line_profiler_results = optimized_line_profiler_results
652602 )
653603 return best_optimization
654604
@@ -662,7 +612,7 @@ def refine_optimizations(
662612 ai_service_client : AiServiceClient ,
663613 executor : concurrent .futures .ThreadPoolExecutor ,
664614 fto_name : str ,
665- ) -> dict [ str , str ]:
615+ ) -> list [ OptimizedCandidate ]:
666616 request = [
667617 AIServiceRefinerRequest (
668618 optimization_id = opt .candidate .optimization_id ,
@@ -680,7 +630,7 @@ def refine_optimizations(
680630 fto_name = fto_name ,
681631 )
682632 for opt in valid_optimizations
683- ]
633+ ] # TODO: multiple workers for this?
684634 future_refinement_results = executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
685635 concurrent .futures .wait ([future_refinement_results ])
686636 return future_refinement_results .result ()
0 commit comments