|
19 | 19 | from rich.tree import Tree |
20 | 20 |
|
21 | 21 | from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient |
22 | | -from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success |
| 22 | +from codeflash.api.cfapi import add_code_context_hash, create_staging, mark_optimization_success |
23 | 23 | from codeflash.benchmarking.utils import process_benchmark_data |
24 | 24 | from codeflash.cli_cmds.console import code_print, console, logger, progress_bar |
25 | 25 | from codeflash.code_utils import env_utils |
@@ -1088,70 +1088,99 @@ def find_and_process_best_optimization( |
1088 | 1088 | original_code_combined[explanation.file_path] = self.function_to_optimize_source_code |
1089 | 1089 | new_code_combined = new_helper_code.copy() |
1090 | 1090 | new_code_combined[explanation.file_path] = new_code |
1091 | | - if not self.args.no_pr: |
1092 | | - coverage_message = ( |
1093 | | - original_code_baseline.coverage_results.build_message() |
1094 | | - if original_code_baseline.coverage_results |
1095 | | - else "Coverage data not available" |
1096 | | - ) |
1097 | | - generated_tests = remove_functions_from_generated_tests( |
1098 | | - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove |
1099 | | - ) |
1100 | | - original_runtime_by_test = ( |
1101 | | - original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() |
1102 | | - ) |
1103 | | - optimized_runtime_by_test = ( |
1104 | | - best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() |
1105 | | - ) |
1106 | | - qualified_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root) |
1107 | | - # Add runtime comments to generated tests before creating the PR |
1108 | | - generated_tests = add_runtime_comments_to_generated_tests( |
1109 | | - generated_tests, original_runtime_by_test, optimized_runtime_by_test |
1110 | | - ) |
1111 | | - generated_tests_str = "\n\n".join( |
1112 | | - [test.generated_original_test_source for test in generated_tests.generated_tests] |
1113 | | - ) |
1114 | | - existing_tests = existing_tests_source_for( |
1115 | | - qualified_name, |
1116 | | - function_to_all_tests, |
1117 | | - test_cfg=self.test_cfg, |
1118 | | - original_runtimes_all=original_runtime_by_test, |
1119 | | - optimized_runtimes_all=optimized_runtime_by_test, |
1120 | | - ) |
1121 | | - if concolic_test_str: |
1122 | | - generated_tests_str += "\n\n" + concolic_test_str |
1123 | | - |
1124 | | - check_create_pr( |
1125 | | - original_code=original_code_combined, |
1126 | | - new_code=new_code_combined, |
1127 | | - explanation=explanation, |
1128 | | - existing_tests_source=existing_tests, |
1129 | | - generated_original_test_source=generated_tests_str, |
1130 | | - function_trace_id=self.function_trace_id[:-4] + exp_type |
1131 | | - if self.experiment_id |
1132 | | - else self.function_trace_id, |
1133 | | - coverage_message=coverage_message, |
1134 | | - git_remote=self.args.git_remote, |
1135 | | - ) |
1136 | | - if ( |
1137 | | - self.args.all |
1138 | | - or env_utils.get_pr_number() |
1139 | | - or self.args.replay_test |
1140 | | - or (self.args.file and not self.args.function) |
1141 | | - ): |
1142 | | - self.write_code_and_helpers( |
1143 | | - self.function_to_optimize_source_code, |
1144 | | - original_helper_code, |
1145 | | - self.function_to_optimize.file_path, |
1146 | | - ) |
1147 | | - else: |
1148 | | - # Mark optimization success since no PR will be created |
1149 | | - mark_optimization_success( |
1150 | | - trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None |
1151 | | - ) |
| 1091 | + self.process_review( |
| 1092 | + original_code_baseline, |
| 1093 | + best_optimization, |
| 1094 | + generated_tests, |
| 1095 | + test_functions_to_remove, |
| 1096 | + concolic_test_str, |
| 1097 | + original_code_combined, |
| 1098 | + new_code_combined, |
| 1099 | + explanation, |
| 1100 | + function_to_all_tests, |
| 1101 | + exp_type, |
| 1102 | + original_helper_code, |
| 1103 | + ) |
1152 | 1104 | self.log_successful_optimization(explanation, generated_tests, exp_type) |
1153 | 1105 | return best_optimization |
1154 | 1106 |
|
| 1107 | + def process_review( |
| 1108 | + self, |
| 1109 | + original_code_baseline: OriginalCodeBaseline, |
| 1110 | + best_optimization: BestOptimization, |
| 1111 | + generated_tests: GeneratedTestsList, |
| 1112 | + test_functions_to_remove: list[str], |
| 1113 | + concolic_test_str: str | None, |
| 1114 | + original_code_combined: dict[Path, str], |
| 1115 | + new_code_combined: dict[Path, str], |
| 1116 | + explanation: Explanation, |
| 1117 | + function_to_all_tests: dict[str, set[FunctionCalledInTest]], |
| 1118 | + exp_type: str, |
| 1119 | + original_helper_code: dict[Path, str], |
| 1120 | + ) -> None: |
| 1121 | + coverage_message = ( |
| 1122 | + original_code_baseline.coverage_results.build_message() |
| 1123 | + if original_code_baseline.coverage_results |
| 1124 | + else "Coverage data not available" |
| 1125 | + ) |
| 1126 | + |
| 1127 | + generated_tests = remove_functions_from_generated_tests( |
| 1128 | + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove |
| 1129 | + ) |
| 1130 | + |
| 1131 | + original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() |
| 1132 | + optimized_runtime_by_test = ( |
| 1133 | + best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() |
| 1134 | + ) |
| 1135 | + |
| 1136 | + generated_tests = add_runtime_comments_to_generated_tests( |
| 1137 | + generated_tests, original_runtime_by_test, optimized_runtime_by_test |
| 1138 | + ) |
| 1139 | + |
| 1140 | + generated_tests_str = "\n\n".join( |
| 1141 | + [test.generated_original_test_source for test in generated_tests.generated_tests] |
| 1142 | + ) |
| 1143 | + if concolic_test_str: |
| 1144 | + generated_tests_str += "\n\n" + concolic_test_str |
| 1145 | + |
| 1146 | + existing_tests = existing_tests_source_for( |
| 1147 | + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), |
| 1148 | + function_to_all_tests, |
| 1149 | + test_cfg=self.test_cfg, |
| 1150 | + original_runtimes_all=original_runtime_by_test, |
| 1151 | + optimized_runtimes_all=optimized_runtime_by_test, |
| 1152 | + ) |
| 1153 | + |
| 1154 | + data = { |
| 1155 | + "original_code": original_code_combined, |
| 1156 | + "new_code": new_code_combined, |
| 1157 | + "explanation": explanation, |
| 1158 | + "existing_tests_source": existing_tests, |
| 1159 | + "generated_original_test_source": generated_tests_str, |
| 1160 | + "function_trace_id": self.function_trace_id[:-4] + exp_type |
| 1161 | + if self.experiment_id |
| 1162 | + else self.function_trace_id, |
| 1163 | + "coverage_message": coverage_message, |
| 1164 | + } |
| 1165 | + |
| 1166 | + if not self.args.no_pr and not self.args.staging_review: |
| 1167 | + data["git_remote"] = self.args.git_remote |
| 1168 | + check_create_pr(**data) |
| 1169 | + elif self.args.staging_review: |
| 1170 | + create_staging(**data) |
| 1171 | + else: |
| 1172 | + # Mark optimization success since no PR will be created |
| 1173 | + mark_optimization_success( |
| 1174 | + trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None |
| 1175 | + ) |
| 1176 | + |
| 1177 | + if ((not self.args.no_pr) or not self.args.staging_review) and ( |
| 1178 | + self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function) |
| 1179 | + ): |
| 1180 | + self.write_code_and_helpers( |
| 1181 | + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path |
| 1182 | + ) |
| 1183 | + |
1155 | 1184 | def establish_original_code_baseline( |
1156 | 1185 | self, |
1157 | 1186 | code_context: CodeOptimizationContext, |
|
0 commit comments