6767 GeneratedTests ,
6868 GeneratedTestsList ,
6969 OptimizationSet ,
70+ OptimizedCandidate ,
7071 OptimizedCandidateResult ,
7172 OriginalCodeBaseline ,
7273 TestFile ,
7374 TestFiles ,
7475 TestingMode ,
7576 TestResults ,
7677 TestType ,
77- OptimizedCandidate
7878)
7979from codeflash .result .create_pr import check_create_pr , existing_tests_source_for
8080from codeflash .result .critic import coverage_critic , performance_gain , quantity_of_tests_critic , speedup_critic
9494
9595 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
9696 from codeflash .either import Result
97- from codeflash .models .models import (
98- BenchmarkKey ,
99- CoverageData ,
100- FunctionCalledInTest ,
101- FunctionSource ,
102- OptimizedCandidate ,
103- )
97+ from codeflash .models .models import BenchmarkKey , CoverageData , FunctionCalledInTest , FunctionSource
10498 from codeflash .verification .verification_utils import TestConfig
10599
106100
@@ -381,7 +375,7 @@ def determine_best_candidate(
381375 candidates = deque (candidates )
382376 refinement_done = False
383377 future_all_refinements : list [concurrent .futures .Future ] = []
384- ast_code_to_id = dict ()
378+ ast_code_to_id = {}
385379 # Start a new thread for AI service request, start loop in main thread
386380 # check if aiservice request is complete, when it is complete, append result to the candidates list
387381 ai_service_client = self .aiservice_client if exp_type == "EXP0" else self .local_aiservice_client
@@ -442,8 +436,11 @@ def determine_best_candidate(
442436 ast_code_to_id [normalized_code ]["shorter_source_code" ] = candidate .source_code
443437 ast_code_to_id [normalized_code ]["diff_len" ] = new_diff_len
444438 continue
445- else :
446- ast_code_to_id [normalized_code ] = {'optimization_id' :candidate .optimization_id , 'shorter_source_code' :candidate .source_code , 'diff_len' :diff_length (candidate .source_code , code_context .read_writable_code )}
439+ ast_code_to_id [normalized_code ] = {
440+ "optimization_id" : candidate .optimization_id ,
441+ "shorter_source_code" : candidate .source_code ,
442+ "diff_len" : diff_length (candidate .source_code , code_context .read_writable_code ),
443+ }
447444 run_results = self .run_optimized_candidate (
448445 optimization_candidate_index = candidate_index ,
449446 baseline_results = original_code_baseline ,
@@ -581,13 +578,17 @@ def determine_best_candidate(
581578 if not len (self .valid_optimizations ):
582579 return None
583580 # need to figure out the best candidate here before we return best_optimization
584- #reassign the shorter code here
581+ # reassign the shorter code here
585582 valid_candidates_with_shorter_code = []
586583 diff_lens_list = [] # character level diff
587584 runtimes_list = []
588585 for valid_opt in self .valid_optimizations :
589586 valid_opt_normalized_code = ast .unparse (ast .parse (valid_opt .candidate .source_code .strip ()))
590- new_candidate_with_shorter_code = OptimizedCandidate (source_code = ast_code_to_id [valid_opt_normalized_code ]["shorter_source_code" ], optimization_id = valid_opt .candidate .optimization_id , explanation = valid_opt .candidate .explanation )
587+ new_candidate_with_shorter_code = OptimizedCandidate (
588+ source_code = ast_code_to_id [valid_opt_normalized_code ]["shorter_source_code" ],
589+ optimization_id = valid_opt .candidate .optimization_id ,
590+ explanation = valid_opt .candidate .explanation ,
591+ )
591592 new_best_opt = BestOptimization (
592593 candidate = new_candidate_with_shorter_code ,
593594 helper_functions = valid_opt .helper_functions ,
@@ -610,7 +611,7 @@ def determine_best_candidate(
610611 overall_ranking = {key : diff_lens_ranking [key ] + runtimes_ranking [key ] for key in diff_lens_ranking .keys ()} # noqa: SIM118
611612 min_key = min (overall_ranking , key = overall_ranking .get )
612613 best_optimization = valid_candidates_with_shorter_code [min_key ]
613- #reassign code string which is the shortest
614+ # reassign code string which is the shortest
614615 ai_service_client .log_results (
615616 function_trace_id = self .function_trace_id [:- 4 ] + exp_type if self .experiment_id else self .function_trace_id ,
616617 speedup_ratio = speedup_ratios ,
0 commit comments