|
94 | 94 |
|
95 | 95 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
96 | 96 | from codeflash.either import Result |
97 | | - from codeflash.models.models import BenchmarkKey, CoverageData, FunctionCalledInTest, FunctionSource |
98 | 97 | from codeflash.models.models import ( |
99 | 98 | BenchmarkKey, |
100 | 99 | CodeStringsMarkdown, |
101 | 100 | CoverageData, |
102 | 101 | FunctionCalledInTest, |
103 | 102 | FunctionSource, |
104 | | - OptimizedCandidate, |
105 | 103 | ) |
106 | 104 | from codeflash.verification.verification_utils import TestConfig |
107 | 105 |
|
@@ -385,6 +383,7 @@ def determine_best_candidate( |
385 | 383 | future_all_refinements: list[concurrent.futures.Future] = [] |
386 | 384 | ast_code_to_id = {} |
387 | 385 | valid_optimizations = [] |
| 386 | + optimizations_post = {} # some candidates' code strings must be overwritten: those candidates are not evaluated themselves, and a shorter/longer equivalent version may be evaluated instead |
388 | 387 | # Start a new thread for AI service request, start loop in main thread |
389 | 388 | # check if aiservice request is complete, when it is complete, append result to the candidates list |
390 | 389 | ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client |
@@ -438,17 +437,37 @@ def determine_best_candidate( |
438 | 437 | self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path |
439 | 438 | ) |
440 | 439 | continue |
441 | | - normalized_code = ast.unparse(ast.parse(candidate.source_code.strip())) |
| 440 | + # check whether this code has already been evaluated by looking up its AST-normalized code string |
| 441 | + normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip())) |
442 | 442 | if normalized_code in ast_code_to_id: |
443 | | - new_diff_len = diff_length(candidate.source_code, code_context.read_writable_code) |
| 443 | + # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes |
| 444 | + speedup_ratios[candidate.optimization_id] = speedup_ratios[ |
| 445 | + ast_code_to_id[normalized_code]["optimization_id"] |
| 446 | + ] |
| 447 | + is_correct[candidate.optimization_id] = is_correct[ |
| 448 | + ast_code_to_id[normalized_code]["optimization_id"] |
| 449 | + ] |
| 450 | + optimized_runtimes[candidate.optimization_id] = optimized_runtimes[ |
| 451 | + ast_code_to_id[normalized_code]["optimization_id"] |
| 452 | + ] |
| 453 | + optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ |
| 454 | + ast_code_to_id[normalized_code]["optimization_id"] |
| 455 | + ] |
| 456 | + optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ |
| 457 | + "shorter_source_code" |
| 458 | + ].markdown |
| 459 | + optimizations_post[ast_code_to_id[normalized_code]["optimization_id"]] = ast_code_to_id[ |
| 460 | + normalized_code |
| 461 | + ]["shorter_source_code"].markdown |
| 462 | + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) |
444 | 463 | if new_diff_len < ast_code_to_id[normalized_code]["diff_len"]: |
445 | 464 | ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code |
446 | 465 | ast_code_to_id[normalized_code]["diff_len"] = new_diff_len |
447 | 466 | continue |
448 | 467 | ast_code_to_id[normalized_code] = { |
449 | 468 | "optimization_id": candidate.optimization_id, |
450 | 469 | "shorter_source_code": candidate.source_code, |
451 | | - "diff_len": diff_length(candidate.source_code, code_context.read_writable_code), |
| 470 | + "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), |
452 | 471 | } |
453 | 472 | run_results = self.run_optimized_candidate( |
454 | 473 | optimization_candidate_index=candidate_index, |
@@ -592,7 +611,7 @@ def determine_best_candidate( |
592 | 611 | diff_lens_list = [] # character level diff |
593 | 612 | runtimes_list = [] |
594 | 613 | for valid_opt in valid_optimizations: |
595 | | - valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.strip())) |
| 614 | + valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip())) |
596 | 615 | new_candidate_with_shorter_code = OptimizedCandidate( |
597 | 616 | source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], |
598 | 617 | optimization_id=valid_opt.candidate.optimization_id, |
@@ -628,6 +647,7 @@ def determine_best_candidate( |
628 | 647 | optimized_runtime=optimized_runtimes, |
629 | 648 | is_correct=is_correct, |
630 | 649 | optimized_line_profiler_results=optimized_line_profiler_results, |
| 650 | + optimizations_post=optimizations_post, |
631 | 651 | metadata={"best_optimization_id": best_optimization.candidate.optimization_id}, |
632 | 652 | ) |
633 | 653 | return best_optimization |
|
0 commit comments