|
19 | 19 | from rich.tree import Tree |
20 | 20 |
|
21 | 21 | from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient |
22 | | -from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success |
| 22 | +from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success, create_staging |
23 | 23 | from codeflash.benchmarking.utils import process_benchmark_data |
24 | 24 | from codeflash.cli_cmds.console import code_print, console, logger, progress_bar |
25 | 25 | from codeflash.code_utils import env_utils |
@@ -997,64 +997,76 @@ def find_and_process_best_optimization( |
997 | 997 | original_code_combined[explanation.file_path] = self.function_to_optimize_source_code |
998 | 998 | new_code_combined = new_helper_code.copy() |
999 | 999 | new_code_combined[explanation.file_path] = new_code |
1000 | | - if not self.args.no_pr: |
1001 | | - coverage_message = ( |
1002 | | - original_code_baseline.coverage_results.build_message() |
1003 | | - if original_code_baseline.coverage_results |
1004 | | - else "Coverage data not available" |
1005 | | - ) |
1006 | | - generated_tests = remove_functions_from_generated_tests( |
1007 | | - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove |
1008 | | - ) |
1009 | | - original_runtime_by_test = ( |
1010 | | - original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() |
1011 | | - ) |
1012 | | - optimized_runtime_by_test = ( |
1013 | | - best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() |
1014 | | - ) |
1015 | | - # Add runtime comments to generated tests before creating the PR |
1016 | | - generated_tests = add_runtime_comments_to_generated_tests( |
1017 | | - self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test |
1018 | | - ) |
1019 | | - generated_tests_str = "\n\n".join( |
1020 | | - [test.generated_original_test_source for test in generated_tests.generated_tests] |
1021 | | - ) |
1022 | | - existing_tests = existing_tests_source_for( |
1023 | | - self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), |
1024 | | - function_to_all_tests, |
1025 | | - test_cfg=self.test_cfg, |
1026 | | - original_runtimes_all=original_runtime_by_test, |
1027 | | - optimized_runtimes_all=optimized_runtime_by_test, |
1028 | | - ) |
1029 | | - if concolic_test_str: |
1030 | | - generated_tests_str += "\n\n" + concolic_test_str |
1031 | | - |
1032 | | - check_create_pr( |
1033 | | - original_code=original_code_combined, |
1034 | | - new_code=new_code_combined, |
1035 | | - explanation=explanation, |
1036 | | - existing_tests_source=existing_tests, |
1037 | | - generated_original_test_source=generated_tests_str, |
1038 | | - function_trace_id=self.function_trace_id[:-4] + exp_type |
1039 | | - if self.experiment_id |
1040 | | - else self.function_trace_id, |
1041 | | - coverage_message=coverage_message, |
1042 | | - git_remote=self.args.git_remote, |
1043 | | - ) |
1044 | | - if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function): |
1045 | | - self.write_code_and_helpers( |
1046 | | - self.function_to_optimize_source_code, |
1047 | | - original_helper_code, |
1048 | | - self.function_to_optimize.file_path, |
1049 | | - ) |
1050 | | - else: |
1051 | | - # Mark optimization success since no PR will be created |
1052 | | - mark_optimization_success( |
1053 | | - trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None |
1054 | | - ) |
| 1000 | + self.process_review(original_code_baseline, best_optimization, generated_tests, |
| 1001 | + test_functions_to_remove, concolic_test_str, original_code_combined, new_code_combined, |
| 1002 | + explanation, function_to_all_tests, exp_type, original_helper_code) |
1055 | 1003 | self.log_successful_optimization(explanation, generated_tests, exp_type) |
1056 | 1004 | return best_optimization |
1057 | 1005 |
|
| 1006 | + def process_review( |
| 1007 | + self, original_code_baseline, best_optimization, generated_tests, |
| 1008 | + test_functions_to_remove, concolic_test_str, original_code_combined, |
| 1009 | + new_code_combined, explanation, function_to_all_tests, exp_type, original_helper_code |
| 1010 | + ): |
| 1011 | + coverage_message = ( |
| 1012 | + original_code_baseline.coverage_results.build_message() |
| 1013 | + if original_code_baseline.coverage_results else "Coverage data not available" |
| 1014 | + ) |
| 1015 | + |
| 1016 | + generated_tests = remove_functions_from_generated_tests( |
| 1017 | + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove |
| 1018 | + ) |
| 1019 | + |
| 1020 | + original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() |
| 1021 | + optimized_runtime_by_test = best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() |
| 1022 | + |
| 1023 | + generated_tests = add_runtime_comments_to_generated_tests( |
| 1024 | + self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test |
| 1025 | + ) |
| 1026 | + |
| 1027 | + generated_tests_str = "\n\n".join( |
| 1028 | + [test.generated_original_test_source for test in generated_tests.generated_tests] |
| 1029 | + ) |
| 1030 | + if concolic_test_str: |
| 1031 | + generated_tests_str += "\n\n" + concolic_test_str |
| 1032 | + |
| 1033 | + existing_tests = existing_tests_source_for( |
| 1034 | + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), |
| 1035 | + function_to_all_tests, |
| 1036 | + test_cfg=self.test_cfg, |
| 1037 | + original_runtimes_all=original_runtime_by_test, |
| 1038 | + optimized_runtimes_all=optimized_runtime_by_test, |
| 1039 | + ) |
| 1040 | + |
| 1041 | + data = dict( |
| 1042 | + original_code=original_code_combined, |
| 1043 | + new_code=new_code_combined, |
| 1044 | + explanation=explanation, |
| 1045 | + existing_tests_source=existing_tests, |
| 1046 | + generated_original_test_source=generated_tests_str, |
| 1047 | + function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, |
| 1048 | + coverage_message=coverage_message, |
| 1049 | + ) |
| 1050 | + |
| 1051 | + if not self.args.no_pr and not self.args.staging_review: |
| 1052 | + data["git_remote"] = self.args.git_remote |
| 1053 | + check_create_pr(**data) |
| 1054 | + elif self.args.staging_review: |
| 1055 | + create_staging(**data) |
| 1056 | + else: |
| 1057 | + # Mark optimization success since no PR will be created |
| 1058 | + mark_optimization_success( |
| 1059 | + trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None |
| 1060 | + ) |
| 1061 | + |
| 1062 | + if (not self.args.no_pr) or not self.args.staging_review: |
| 1063 | + if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function): |
| 1064 | + self.write_code_and_helpers( |
| 1065 | + self.function_to_optimize_source_code, |
| 1066 | + original_helper_code, |
| 1067 | + self.function_to_optimize.file_path, |
| 1068 | + ) |
| 1069 | + |
1058 | 1070 | def establish_original_code_baseline( |
1059 | 1071 | self, |
1060 | 1072 | code_context: CodeOptimizationContext, |
|
0 commit comments