@@ -147,7 +147,9 @@ def __init__(
         self.generate_and_instrument_tests_results: (
             tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet] | None
         ) = None
-        self.valid_optimizations: list[BestOptimization] = list()  # TODO: Figure out the dataclass type for this
+        self.valid_optimizations: list[BestOptimization] = (
+            list()  # TODO: Figure out the dataclass type for this # noqa: C408
+        )
 
     def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
         should_run_experiment = self.experiment_id is not None
@@ -362,7 +364,7 @@ def determine_best_candidate(
         from codeflash.models.models import OptimizedCandidate
 
         best_optimization: BestOptimization | None = None
-        best_runtime_until_now = original_code_baseline.runtime
+        _best_runtime_until_now = original_code_baseline.runtime
 
         speedup_ratios: dict[str, float | None] = {}
         optimized_runtimes: dict[str, float | None] = {}
@@ -510,7 +512,6 @@ def determine_best_candidate(
                         winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
                     )
                     self.valid_optimizations.append(best_optimization)
-                    best_runtime_until_now = best_test_runtime
                 else:
                     tree.add(
                         f"Summed runtime: {humanize_runtime(best_test_runtime)} "
@@ -543,18 +544,22 @@ def determine_best_candidate(
             if len(candidates) == 0 and len(self.valid_optimizations) > 0 and not refinement_done:
                 # TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
                 # are found. This way we can hide the time waiting for the LLM results.
+                trace_id = self.function_trace_id
+                if trace_id.endswith(("EXP0", "EXP1")):
+                    trace_id = trace_id[:-4] + exp_type
                 refinement_diffs = self.refine_optimizations(
                     valid_optimizations=self.valid_optimizations,
                     original_code_baseline=original_code_baseline,
                     code_context=code_context,
-                    trace_id=self.function_trace_id[:-4] + exp_type,
+                    trace_id=trace_id,
                     experiment_metadata=ExperimentMetadata(
                         id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
                     )
                     if self.experiment_id
                     else None,
                     ai_service_client=ai_service_client,
                     executor=executor,
+                    fto_name=self.function_to_optimize.qualified_name,
                 )
                 # filter out empty strings of code
                 more_opt_candidates = [
@@ -581,13 +586,11 @@ def determine_best_candidate(
581586 def diff_length (a : str , b : str ) -> int :
582587 """Compute the length (in characters) of the unified diff between two strings.
583588
584- Parameters
585- ----------
589+ Args:
586590 a (str): Original string.
587591 b (str): Modified string.
588592
589- Returns
590- -------
593+ Returns:
591594 int: Total number of characters in the diff.
592595
593596 """
@@ -604,7 +607,8 @@ def diff_length(a: str, b: str) -> int:
             return len(diff_text)
 
         def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
-            """Creates a dictionary from a list of ints, mapping the original index to its rank.
+            """Create a dictionary from a list of ints, mapping the original index to its rank.
+
             This version uses a more compact, "Pythonic" implementation.
 
             Args:
@@ -631,7 +635,7 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
             runtimes_list.append(valid_opt.runtime)
         diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
         runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
-        overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()}
+        overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()}  # noqa: SIM118
        min_key = min(overall_ranking, key=overall_ranking.get)
         ai_service_client.log_results(
             function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
@@ -651,6 +655,7 @@ def refine_optimizations(
         experiment_metadata: ExperimentMetadata | None,
         ai_service_client: AiServiceClient,
         executor: concurrent.futures.ThreadPoolExecutor,
+        fto_name: str,
     ) -> list[str]:
         request = [
             AIServiceRefinerRequest(
@@ -665,13 +670,13 @@ def refine_optimizations(
665670 original_line_profiler_results = original_code_baseline .line_profile_results ["str_out" ],
666671 optimized_line_profiler_results = opt .line_profiler_test_results ["str_out" ],
667672 experiment_metadata = experiment_metadata ,
673+ fto_name = fto_name ,
668674 )
669675 for opt in valid_optimizations
670676 ]
671677 future_refinement_results = executor .submit (ai_service_client .optimize_python_code_refinement , request = request )
672678 concurrent .futures .wait ([future_refinement_results ])
673- refinement_results = future_refinement_results .result ()
674- return refinement_results
679+ return future_refinement_results .result ()
675680
676681 def log_successful_optimization (
677682 self , explanation : Explanation , generated_tests : GeneratedTestsList , exp_type : str
0 commit comments