@@ -280,6 +280,7 @@ def __init__(
280280 max_workers = n_tests + 3 if self .experiment_id is None else n_tests + 4
281281 )
282282 self .optimization_review = ""
283+ self .ast_code_to_id = {}
283284 # SQLite database setup for logging
284285 self .code_repair_log_db = Path (__file__ ).parent / "code_repair_logs_cf.db"
285286
@@ -519,7 +520,7 @@ def determine_best_candidate(
519520 console .rule ()
520521
521522 future_all_refinements : list [concurrent .futures .Future ] = []
522- ast_code_to_id = {}
523+ self . ast_code_to_id . clear ()
523524 valid_optimizations = []
524525 optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
525526
@@ -598,11 +599,11 @@ def determine_best_candidate(
598599 continue
599600 # check if this code has been evaluated before by checking the ast normalized code string
600601 normalized_code = normalize_code (candidate .source_code .flat .strip ())
601- if normalized_code in ast_code_to_id :
602+ if normalized_code in self . ast_code_to_id :
602603 logger .info (
603604 "Current candidate has been encountered before in testing, Skipping optimization candidate."
604605 )
605- past_opt_id = ast_code_to_id [normalized_code ]["optimization_id" ]
606+ past_opt_id = self . ast_code_to_id [normalized_code ]["optimization_id" ]
606607 # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes
607608 speedup_ratios [candidate .optimization_id ] = speedup_ratios [past_opt_id ]
608609 is_correct [candidate .optimization_id ] = is_correct [past_opt_id ]
@@ -612,16 +613,18 @@ def determine_best_candidate(
612613 optimized_line_profiler_results [candidate .optimization_id ] = optimized_line_profiler_results [
613614 past_opt_id
614615 ]
615- optimizations_post [candidate .optimization_id ] = ast_code_to_id [normalized_code ][
616+ optimizations_post [candidate .optimization_id ] = self .ast_code_to_id [normalized_code ][
617+ "shorter_source_code"
618+ ].markdown
619+ optimizations_post [past_opt_id ] = self .ast_code_to_id [normalized_code ][
616620 "shorter_source_code"
617621 ].markdown
618- optimizations_post [past_opt_id ] = ast_code_to_id [normalized_code ]["shorter_source_code" ].markdown
619622 new_diff_len = diff_length (candidate .source_code .flat , code_context .read_writable_code .flat )
620623 if (
621- new_diff_len < ast_code_to_id [normalized_code ]["diff_len" ]
624+ new_diff_len < self . ast_code_to_id [normalized_code ]["diff_len" ]
622625 ): # new candidate has a shorter diff than the previously encountered one
623- ast_code_to_id [normalized_code ]["shorter_source_code" ] = candidate .source_code
624- ast_code_to_id [normalized_code ]["diff_len" ] = new_diff_len
626+ self . ast_code_to_id [normalized_code ]["shorter_source_code" ] = candidate .source_code
627+ self . ast_code_to_id [normalized_code ]["diff_len" ] = new_diff_len
625628 if candidate .optimization_id .endswith ("cdrp" ):
626629 log_code_repair_to_db (
627630 code_repair_log_db = self .code_repair_log_db ,
@@ -636,7 +639,7 @@ def determine_best_candidate(
636639 else "no" ,
637640 )
638641 continue
639- ast_code_to_id [normalized_code ] = {
642+ self . ast_code_to_id [normalized_code ] = {
640643 "optimization_id" : candidate .optimization_id ,
641644 "shorter_source_code" : candidate .source_code ,
642645 "diff_len" : diff_length (candidate .source_code .flat , code_context .read_writable_code .flat ),
@@ -657,6 +660,9 @@ def determine_best_candidate(
657660 speedup_ratios [candidate .optimization_id ] = None
658661 else :
659662 candidate_result : OptimizedCandidateResult = run_results .unwrap ()
663+ # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair
664+ if candidate .optimization_id != candidate_result .optimized_candidate .optimization_id :
665+ candidate = candidate_result .optimized_candidate
660666 best_test_runtime = candidate_result .best_test_runtime
661667 optimized_runtimes [candidate .optimization_id ] = best_test_runtime
662668 is_correct [candidate .optimization_id ] = True
@@ -821,7 +827,7 @@ def determine_best_candidate(
821827 for valid_opt in valid_optimizations :
822828 valid_opt_normalized_code = normalize_code (valid_opt .candidate .source_code .flat .strip ())
823829 new_candidate_with_shorter_code = OptimizedCandidate (
824- source_code = ast_code_to_id [valid_opt_normalized_code ]["shorter_source_code" ],
830+ source_code = self . ast_code_to_id [valid_opt_normalized_code ]["shorter_source_code" ],
825831 optimization_id = valid_opt .candidate .optimization_id ,
826832 explanation = valid_opt .candidate .explanation ,
827833 )
@@ -1946,7 +1952,18 @@ def run_optimized_candidate(
19461952
19471953 code_print (new_candidate .source_code .flat )
19481954
1955+ normalized_code = normalize_code (candidate .source_code .flat .strip ())
1956+ self .ast_code_to_id [normalized_code ] = {
1957+ "optimization_id" : candidate .optimization_id ,
1958+ "shorter_source_code" : candidate .source_code ,
1959+ "diff_len" : diff_length (candidate .source_code .flat , code_context .read_writable_code .flat ),
1960+ }
1961+
19491962 try :
1963+ # revert first to original code then replace with new repaired code, so we don't get any weird behavior
1964+ self .write_code_and_helpers (
1965+ self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
1966+ )
19501967 did_update = self .replace_function_and_helpers_with_optimized_code (
19511968 code_context = code_context ,
19521969 optimized_code = new_candidate .source_code ,
@@ -2048,6 +2065,7 @@ def run_optimized_candidate(
20482065 )
20492066 return Success (
20502067 OptimizedCandidateResult (
2068+ optimized_candidate = candidate ,
20512069 max_loop_count = loop_count ,
20522070 best_test_runtime = total_candidate_timing ,
20532071 behavior_test_results = candidate_behavior_results ,
0 commit comments