|
19 | 19 | from rich.tree import Tree |
20 | 20 |
|
21 | 21 | from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient |
22 | | -from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success |
| 22 | +from codeflash.api.cfapi import add_code_context_hash, create_staging, mark_optimization_success |
23 | 23 | from codeflash.benchmarking.utils import process_benchmark_data |
24 | 24 | from codeflash.cli_cmds.console import code_print, console, logger, progress_bar |
25 | 25 | from codeflash.code_utils import env_utils |
@@ -1088,97 +1088,124 @@ def find_and_process_best_optimization( |
1088 | 1088 | original_code_combined[explanation.file_path] = self.function_to_optimize_source_code |
1089 | 1089 | new_code_combined = new_helper_code.copy() |
1090 | 1090 | new_code_combined[explanation.file_path] = new_code |
1091 | | - if not self.args.no_pr: |
1092 | | - coverage_message = ( |
1093 | | - original_code_baseline.coverage_results.build_message() |
1094 | | - if original_code_baseline.coverage_results |
1095 | | - else "Coverage data not available" |
1096 | | - ) |
1097 | | - generated_tests = remove_functions_from_generated_tests( |
1098 | | - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove |
1099 | | - ) |
1100 | | - original_runtime_by_test = ( |
1101 | | - original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() |
1102 | | - ) |
1103 | | - optimized_runtime_by_test = ( |
1104 | | - best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() |
1105 | | - ) |
1106 | | - qualified_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root) |
1107 | | - # Add runtime comments to generated tests before creating the PR |
1108 | | - generated_tests = add_runtime_comments_to_generated_tests( |
1109 | | - generated_tests, original_runtime_by_test, optimized_runtime_by_test |
1110 | | - ) |
1111 | | - generated_tests_str = "\n#------------------------------------------------\n".join( |
1112 | | - [test.generated_original_test_source for test in generated_tests.generated_tests] |
1113 | | - ) |
1114 | | - existing_tests = existing_tests_source_for( |
1115 | | - qualified_name, |
1116 | | - function_to_all_tests, |
1117 | | - test_cfg=self.test_cfg, |
1118 | | - original_runtimes_all=original_runtime_by_test, |
1119 | | - optimized_runtimes_all=optimized_runtime_by_test, |
1120 | | - ) |
1121 | | - if concolic_test_str: |
1122 | | - generated_tests_str += ( |
1123 | | - "\n#------------------------------------------------\n" + concolic_test_str |
1124 | | - ) |
1125 | | - new_explanation_raw_str = self.aiservice_client.get_new_explanation( |
1126 | | - source_code=code_context.read_writable_code, |
1127 | | - dependency_code=code_context.read_only_context_code, |
1128 | | - trace_id=self.function_trace_id[:-4] + exp_type |
1129 | | - if self.experiment_id |
1130 | | - else self.function_trace_id, |
1131 | | - optimized_code=best_optimization.candidate.source_code, |
1132 | | - original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], |
1133 | | - optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], |
1134 | | - original_code_runtime=humanize_runtime(original_code_baseline.runtime), |
1135 | | - optimized_code_runtime=humanize_runtime(best_optimization.runtime), |
1136 | | - speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%", |
1137 | | - annotated_tests=generated_tests_str, |
1138 | | - optimization_id=best_optimization.candidate.optimization_id, |
1139 | | - original_explanation=best_optimization.candidate.explanation, |
1140 | | - ) |
1141 | | - new_explanation = Explanation( |
1142 | | - raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message, |
1143 | | - winning_behavior_test_results=explanation.winning_behavior_test_results, |
1144 | | - winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, |
1145 | | - original_runtime_ns=explanation.original_runtime_ns, |
1146 | | - best_runtime_ns=explanation.best_runtime_ns, |
1147 | | - function_name=explanation.function_name, |
1148 | | - file_path=explanation.file_path, |
1149 | | - benchmark_details=explanation.benchmark_details, |
1150 | | - ) |
1151 | | - check_create_pr( |
1152 | | - original_code=original_code_combined, |
1153 | | - new_code=new_code_combined, |
1154 | | - explanation=new_explanation, |
1155 | | - existing_tests_source=existing_tests, |
1156 | | - generated_original_test_source=generated_tests_str, |
1157 | | - function_trace_id=self.function_trace_id[:-4] + exp_type |
1158 | | - if self.experiment_id |
1159 | | - else self.function_trace_id, |
1160 | | - coverage_message=coverage_message, |
1161 | | - git_remote=self.args.git_remote, |
1162 | | - ) |
1163 | | - if ( |
1164 | | - self.args.all |
1165 | | - or env_utils.get_pr_number() |
1166 | | - or self.args.replay_test |
1167 | | - or (self.args.file and not self.args.function) |
1168 | | - ): |
1169 | | - self.write_code_and_helpers( |
1170 | | - self.function_to_optimize_source_code, |
1171 | | - original_helper_code, |
1172 | | - self.function_to_optimize.file_path, |
1173 | | - ) |
1174 | | - else: |
1175 | | - # Mark optimization success since no PR will be created |
1176 | | - mark_optimization_success( |
1177 | | - trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None |
1178 | | - ) |
| 1091 | + self.process_review( |
| 1092 | + original_code_baseline, |
| 1093 | + best_optimization, |
| 1094 | + generated_tests, |
| 1095 | + test_functions_to_remove, |
| 1096 | + concolic_test_str, |
| 1097 | + original_code_combined, |
| 1098 | + new_code_combined, |
| 1099 | + explanation, |
| 1100 | + function_to_all_tests, |
| 1101 | + exp_type, |
| 1102 | + original_helper_code, |
| 1103 | + code_context, |
| 1104 | + ) |
1179 | 1105 | self.log_successful_optimization(explanation, generated_tests, exp_type) |
1180 | 1106 | return best_optimization |
1181 | 1107 |
|
| 1108 | + def process_review( |
| 1109 | + self, |
| 1110 | + original_code_baseline: OriginalCodeBaseline, |
| 1111 | + best_optimization: BestOptimization, |
| 1112 | + generated_tests: GeneratedTestsList, |
| 1113 | + test_functions_to_remove: list[str], |
| 1114 | + concolic_test_str: str | None, |
| 1115 | + original_code_combined: dict[Path, str], |
| 1116 | + new_code_combined: dict[Path, str], |
| 1117 | + explanation: Explanation, |
| 1118 | + function_to_all_tests: dict[str, set[FunctionCalledInTest]], |
| 1119 | + exp_type: str, |
| 1120 | + original_helper_code: dict[Path, str], |
| 1121 | + code_context: CodeOptimizationContext, |
| 1122 | + ) -> None: |
| 1123 | + coverage_message = ( |
| 1124 | + original_code_baseline.coverage_results.build_message() |
| 1125 | + if original_code_baseline.coverage_results |
| 1126 | + else "Coverage data not available" |
| 1127 | + ) |
| 1128 | + |
| 1129 | + generated_tests = remove_functions_from_generated_tests( |
| 1130 | + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove |
| 1131 | + ) |
| 1132 | + |
| 1133 | + original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() |
| 1134 | + optimized_runtime_by_test = ( |
| 1135 | + best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() |
| 1136 | + ) |
| 1137 | + |
| 1138 | + generated_tests = add_runtime_comments_to_generated_tests( |
| 1139 | + generated_tests, original_runtime_by_test, optimized_runtime_by_test |
| 1140 | + ) |
| 1141 | + |
| 1142 | + generated_tests_str = "\n#------------------------------------------------\n".join( |
| 1143 | + [test.generated_original_test_source for test in generated_tests.generated_tests] |
| 1144 | + ) |
| 1145 | + if concolic_test_str: |
| 1146 | + generated_tests_str += "\n#------------------------------------------------\n" + concolic_test_str |
| 1147 | + |
| 1148 | + existing_tests = existing_tests_source_for( |
| 1149 | + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), |
| 1150 | + function_to_all_tests, |
| 1151 | + test_cfg=self.test_cfg, |
| 1152 | + original_runtimes_all=original_runtime_by_test, |
| 1153 | + optimized_runtimes_all=optimized_runtime_by_test, |
| 1154 | + ) |
| 1155 | + new_explanation_raw_str = self.aiservice_client.get_new_explanation( |
| 1156 | + source_code=code_context.read_writable_code, |
| 1157 | + dependency_code=code_context.read_only_context_code, |
| 1158 | + trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, |
| 1159 | + optimized_code=best_optimization.candidate.source_code, |
| 1160 | + original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], |
| 1161 | + optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], |
| 1162 | + original_code_runtime=humanize_runtime(original_code_baseline.runtime), |
| 1163 | + optimized_code_runtime=humanize_runtime(best_optimization.runtime), |
| 1164 | + speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%", |
| 1165 | + annotated_tests=generated_tests_str, |
| 1166 | + optimization_id=best_optimization.candidate.optimization_id, |
| 1167 | + original_explanation=best_optimization.candidate.explanation, |
| 1168 | + ) |
| 1169 | + new_explanation = Explanation( |
| 1170 | + raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message, |
| 1171 | + winning_behavior_test_results=explanation.winning_behavior_test_results, |
| 1172 | + winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, |
| 1173 | + original_runtime_ns=explanation.original_runtime_ns, |
| 1174 | + best_runtime_ns=explanation.best_runtime_ns, |
| 1175 | + function_name=explanation.function_name, |
| 1176 | + file_path=explanation.file_path, |
| 1177 | + benchmark_details=explanation.benchmark_details, |
| 1178 | + ) |
| 1179 | + data = { |
| 1180 | + "original_code": original_code_combined, |
| 1181 | + "new_code": new_code_combined, |
| 1182 | + "explanation": new_explanation, |
| 1183 | + "existing_tests_source": existing_tests, |
| 1184 | + "generated_original_test_source": generated_tests_str, |
| 1185 | + "function_trace_id": self.function_trace_id[:-4] + exp_type |
| 1186 | + if self.experiment_id |
| 1187 | + else self.function_trace_id, |
| 1188 | + "coverage_message": coverage_message, |
| 1189 | + } |
| 1190 | + |
| 1191 | + if not self.args.no_pr and not self.args.staging_review: |
| 1192 | + data["git_remote"] = self.args.git_remote |
| 1193 | + check_create_pr(**data) |
| 1194 | + elif self.args.staging_review: |
| 1195 | + create_staging(**data) |
| 1196 | + else: |
| 1197 | + # Mark optimization success since no PR will be created |
| 1198 | + mark_optimization_success( |
| 1199 | + trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None |
| 1200 | + ) |
| 1201 | + |
| 1202 | + if ((not self.args.no_pr) or not self.args.staging_review) and ( |
| 1203 | + self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function) |
| 1204 | + ): |
| 1205 | + self.write_code_and_helpers( |
| 1206 | + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path |
| 1207 | + ) |
| 1208 | + |
1182 | 1209 | def establish_original_code_baseline( |
1183 | 1210 | self, |
1184 | 1211 | code_context: CodeOptimizationContext, |
|
0 commit comments