Skip to content

Commit 97f2426

Browse files
fix override candidate after the code repair
1 parent 696448c commit 97f2426

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ class CodeContextType(str, Enum):
302302

303303

304304
class OptimizedCandidateResult(BaseModel):
305+
optimized_candidate: OptimizedCandidate
305306
max_loop_count: int
306307
best_test_runtime: int
307308
behavior_test_results: TestResults

codeflash/optimization/function_optimizer.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def __init__(
280280
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
281281
)
282282
self.optimization_review = ""
283+
self.ast_code_to_id = {}
283284
# SQLite database setup for logging
284285
self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db"
285286

@@ -519,7 +520,7 @@ def determine_best_candidate(
519520
console.rule()
520521

521522
future_all_refinements: list[concurrent.futures.Future] = []
522-
ast_code_to_id = {}
523+
self.ast_code_to_id.clear()
523524
valid_optimizations = []
524525
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
525526

@@ -598,11 +599,11 @@ def determine_best_candidate(
598599
continue
599600
# check whether this code has already been evaluated by looking up its AST-normalized code string
600601
normalized_code = normalize_code(candidate.source_code.flat.strip())
601-
if normalized_code in ast_code_to_id:
602+
if normalized_code in self.ast_code_to_id:
602603
logger.info(
603604
"Current candidate has been encountered before in testing, Skipping optimization candidate."
604605
)
605-
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]
606+
past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]
606607
# update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes
607608
speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id]
608609
is_correct[candidate.optimization_id] = is_correct[past_opt_id]
@@ -612,16 +613,18 @@ def determine_best_candidate(
612613
optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[
613614
past_opt_id
614615
]
615-
optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][
616+
optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][
617+
"shorter_source_code"
618+
].markdown
619+
optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code][
616620
"shorter_source_code"
617621
].markdown
618-
optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown
619622
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
620623
if (
621-
new_diff_len < ast_code_to_id[normalized_code]["diff_len"]
624+
new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]
622625
): # new candidate has a shorter diff than the previously encountered one
623-
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
624-
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
626+
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
627+
self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
625628
if candidate.optimization_id.endswith("cdrp"):
626629
log_code_repair_to_db(
627630
code_repair_log_db=self.code_repair_log_db,
@@ -636,7 +639,7 @@ def determine_best_candidate(
636639
else "no",
637640
)
638641
continue
639-
ast_code_to_id[normalized_code] = {
642+
self.ast_code_to_id[normalized_code] = {
640643
"optimization_id": candidate.optimization_id,
641644
"shorter_source_code": candidate.source_code,
642645
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
@@ -657,6 +660,9 @@ def determine_best_candidate(
657660
speedup_ratios[candidate.optimization_id] = None
658661
else:
659662
candidate_result: OptimizedCandidateResult = run_results.unwrap()
663+
# Override the candidate if its optimization_id has changed; this can happen when the candidate was modified by code repair.
664+
if candidate.optimization_id != candidate_result.optimized_candidate.optimization_id:
665+
candidate = candidate_result.optimized_candidate
660666
best_test_runtime = candidate_result.best_test_runtime
661667
optimized_runtimes[candidate.optimization_id] = best_test_runtime
662668
is_correct[candidate.optimization_id] = True
@@ -821,7 +827,7 @@ def determine_best_candidate(
821827
for valid_opt in valid_optimizations:
822828
valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
823829
new_candidate_with_shorter_code = OptimizedCandidate(
824-
source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
830+
source_code=self.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
825831
optimization_id=valid_opt.candidate.optimization_id,
826832
explanation=valid_opt.candidate.explanation,
827833
)
@@ -1946,7 +1952,18 @@ def run_optimized_candidate(
19461952

19471953
code_print(new_candidate.source_code.flat)
19481954

1955+
normalized_code = normalize_code(candidate.source_code.flat.strip())
1956+
self.ast_code_to_id[normalized_code] = {
1957+
"optimization_id": candidate.optimization_id,
1958+
"shorter_source_code": candidate.source_code,
1959+
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
1960+
}
1961+
19491962
try:
1963+
# Revert to the original code first, then replace it with the repaired code, to avoid unexpected behavior.
1964+
self.write_code_and_helpers(
1965+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1966+
)
19501967
did_update = self.replace_function_and_helpers_with_optimized_code(
19511968
code_context=code_context,
19521969
optimized_code=new_candidate.source_code,
@@ -2048,6 +2065,7 @@ def run_optimized_candidate(
20482065
)
20492066
return Success(
20502067
OptimizedCandidateResult(
2068+
optimized_candidate=candidate,
20512069
max_loop_count=loop_count,
20522070
best_test_runtime=total_candidate_timing,
20532071
behavior_test_results=candidate_behavior_results,

0 commit comments

Comments
 (0)