Skip to content

Commit ecb10ab

Browse files
committed
todo cleanup
1 parent 0d93128 commit ecb10ab

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
TestingMode,
7575
TestResults,
7676
TestType,
77+
OptimizedCandidate
7778
)
7879
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
7980
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
@@ -380,6 +381,7 @@ def determine_best_candidate(
380381
candidates = deque(candidates)
381382
refinement_done = False
382383
future_all_refinements: list[concurrent.futures.Future] = []
384+
ast_code_to_id = dict()
383385
# Start a new thread for AI service request, start loop in main thread
384386
# check if aiservice request is complete, when it is complete, append result to the candidates list
385387
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
@@ -413,6 +415,8 @@ def determine_best_candidate(
413415
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
414416
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
415417
code_print(candidate.source_code)
418+
# map ast normalized code to diff len, unnormalized code
419+
# map opt id to the shortest unnormalized code
416420
try:
417421
did_update = self.replace_function_and_helpers_with_optimized_code(
418422
code_context=code_context,
@@ -431,7 +435,15 @@ def determine_best_candidate(
431435
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
432436
)
433437
continue
434-
438+
normalized_code = ast.unparse(ast.parse(candidate.source_code.strip()))
439+
if normalized_code in ast_code_to_id:
440+
new_diff_len = diff_length(candidate.source_code, code_context.read_writable_code)
441+
if new_diff_len < ast_code_to_id[normalized_code]["diff_len"]:
442+
ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
443+
ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
444+
continue
445+
else:
446+
ast_code_to_id[normalized_code] = {'optimization_id':candidate.optimization_id, 'shorter_source_code':candidate.source_code, 'diff_len':diff_length(candidate.source_code, code_context.read_writable_code)}
435447
run_results = self.run_optimized_candidate(
436448
optimization_candidate_index=candidate_index,
437449
baseline_results=original_code_baseline,
@@ -569,19 +581,36 @@ def determine_best_candidate(
569581
if not len(self.valid_optimizations):
570582
return None
571583
# need to figure out the best candidate here before we return best_optimization
584+
#reassign the shorter code here
585+
valid_candidates_with_shorter_code = []
572586
diff_lens_list = [] # character level diff
573587
runtimes_list = []
574588
for valid_opt in self.valid_optimizations:
589+
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.strip()))
590+
new_candidate_with_shorter_code = OptimizedCandidate(source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, explanation=valid_opt.candidate.explanation)
591+
new_best_opt = BestOptimization(
592+
candidate=new_candidate_with_shorter_code,
593+
helper_functions=valid_opt.helper_functions,
594+
code_context=valid_opt.code_context,
595+
runtime=valid_opt.runtime,
596+
line_profiler_test_results=valid_opt.line_profiler_test_results,
597+
winning_behavior_test_results=valid_opt.winning_behavior_test_results,
598+
replay_performance_gain=valid_opt.replay_performance_gain,
599+
winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
600+
winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
601+
)
602+
valid_candidates_with_shorter_code.append(new_best_opt)
575603
diff_lens_list.append(
576-
diff_length(valid_opt.candidate.source_code, code_context.read_writable_code)
604+
diff_length(new_best_opt.candidate.source_code, code_context.read_writable_code)
577605
) # char level diff
578-
runtimes_list.append(valid_opt.runtime)
606+
runtimes_list.append(new_best_opt.runtime)
579607
diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
580608
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
581609
# TODO: better way to resolve conflicts with same min ranking
582610
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
583611
min_key = min(overall_ranking, key=overall_ranking.get)
584-
best_optimization = self.valid_optimizations[min_key]
612+
best_optimization = valid_candidates_with_shorter_code[min_key]
613+
#reassign code string which is the shortest
585614
ai_service_client.log_results(
586615
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
587616
speedup_ratio=speedup_ratios,

0 commit comments

Comments
 (0)