Skip to content

Commit 26d7f63

Browse files
committed
merged changes from main
2 parents 388890c + 2e46088 commit 26d7f63

File tree

3 files changed

+174
-91
lines changed

3 files changed

+174
-91
lines changed

codeflash/api/cfapi.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_repo_owner_and_name
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
18+
from codeflash.github.PrComment import FileDiffContent, PrComment
1819
from codeflash.version import __version__
1920

2021
if TYPE_CHECKING:
2122
from requests import Response
2223

23-
from codeflash.github.PrComment import FileDiffContent, PrComment
24+
from codeflash.result.explanation import Explanation
25+
2426
from packaging import version
2527

2628
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
@@ -182,6 +184,59 @@ def create_pr(
182184
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)
183185

184186

187+
def create_staging(
188+
original_code: dict[Path, str],
189+
new_code: dict[Path, str],
190+
explanation: Explanation,
191+
existing_tests_source: str,
192+
generated_original_test_source: str,
193+
function_trace_id: str,
194+
coverage_message: str,
195+
) -> Response:
196+
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
197+
198+
:param original_code: A mapping of file paths to original source code.
199+
:param new_code: A mapping of file paths to optimized source code.
200+
:param explanation: An Explanation object with optimization details.
201+
:param existing_tests_source: Existing test code.
202+
:param generated_original_test_source: Generated tests for the original function.
203+
:param function_trace_id: Unique identifier for this optimization trace.
204+
:param coverage_message: Coverage report or summary.
205+
:return: The response object from the backend.
206+
"""
207+
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
208+
209+
build_file_changes = {
210+
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
211+
oldContent=original_code[p], newContent=new_code[p]
212+
)
213+
for p in original_code
214+
}
215+
216+
payload = {
217+
"baseBranch": get_current_branch(),
218+
"diffContents": build_file_changes,
219+
"prCommentFields": PrComment(
220+
optimization_explanation=explanation.explanation_message(),
221+
best_runtime=explanation.best_runtime_ns,
222+
original_runtime=explanation.original_runtime_ns,
223+
function_name=explanation.function_name,
224+
relative_file_path=relative_path,
225+
speedup_x=explanation.speedup_x,
226+
speedup_pct=explanation.speedup_pct,
227+
winning_behavior_test_results=explanation.winning_behavior_test_results,
228+
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
229+
benchmark_details=explanation.benchmark_details,
230+
).to_json(),
231+
"existingTests": existing_tests_source,
232+
"generatedTests": generated_original_test_source,
233+
"traceId": function_trace_id,
234+
"coverage_message": coverage_message,
235+
}
236+
237+
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)
238+
239+
185240
def is_github_app_installed_on_repo(owner: str, repo: str, *, suppress_errors: bool = False) -> bool:
186241
"""Check if the Codeflash GitHub App is installed on the specified repository.
187242

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def parse_args() -> Namespace:
7777
parser.add_argument(
7878
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
7979
)
80+
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
8081
parser.add_argument(
8182
"--verify-setup",
8283
action="store_true",

codeflash/optimization/function_optimizer.py

Lines changed: 116 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from rich.tree import Tree
2020

2121
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
22-
from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success
22+
from codeflash.api.cfapi import add_code_context_hash, create_staging, mark_optimization_success
2323
from codeflash.benchmarking.utils import process_benchmark_data
2424
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
2525
from codeflash.code_utils import env_utils
@@ -1088,97 +1088,124 @@ def find_and_process_best_optimization(
10881088
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
10891089
new_code_combined = new_helper_code.copy()
10901090
new_code_combined[explanation.file_path] = new_code
1091-
if not self.args.no_pr:
1092-
coverage_message = (
1093-
original_code_baseline.coverage_results.build_message()
1094-
if original_code_baseline.coverage_results
1095-
else "Coverage data not available"
1096-
)
1097-
generated_tests = remove_functions_from_generated_tests(
1098-
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
1099-
)
1100-
original_runtime_by_test = (
1101-
original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
1102-
)
1103-
optimized_runtime_by_test = (
1104-
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
1105-
)
1106-
qualified_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
1107-
# Add runtime comments to generated tests before creating the PR
1108-
generated_tests = add_runtime_comments_to_generated_tests(
1109-
generated_tests, original_runtime_by_test, optimized_runtime_by_test
1110-
)
1111-
generated_tests_str = "\n#------------------------------------------------\n".join(
1112-
[test.generated_original_test_source for test in generated_tests.generated_tests]
1113-
)
1114-
existing_tests = existing_tests_source_for(
1115-
qualified_name,
1116-
function_to_all_tests,
1117-
test_cfg=self.test_cfg,
1118-
original_runtimes_all=original_runtime_by_test,
1119-
optimized_runtimes_all=optimized_runtime_by_test,
1120-
)
1121-
if concolic_test_str:
1122-
generated_tests_str += (
1123-
"\n#------------------------------------------------\n" + concolic_test_str
1124-
)
1125-
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
1126-
source_code=code_context.read_writable_code,
1127-
dependency_code=code_context.read_only_context_code,
1128-
trace_id=self.function_trace_id[:-4] + exp_type
1129-
if self.experiment_id
1130-
else self.function_trace_id,
1131-
optimized_code=best_optimization.candidate.source_code,
1132-
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
1133-
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
1134-
original_code_runtime=humanize_runtime(original_code_baseline.runtime),
1135-
optimized_code_runtime=humanize_runtime(best_optimization.runtime),
1136-
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
1137-
annotated_tests=generated_tests_str,
1138-
optimization_id=best_optimization.candidate.optimization_id,
1139-
original_explanation=best_optimization.candidate.explanation,
1140-
)
1141-
new_explanation = Explanation(
1142-
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
1143-
winning_behavior_test_results=explanation.winning_behavior_test_results,
1144-
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
1145-
original_runtime_ns=explanation.original_runtime_ns,
1146-
best_runtime_ns=explanation.best_runtime_ns,
1147-
function_name=explanation.function_name,
1148-
file_path=explanation.file_path,
1149-
benchmark_details=explanation.benchmark_details,
1150-
)
1151-
check_create_pr(
1152-
original_code=original_code_combined,
1153-
new_code=new_code_combined,
1154-
explanation=new_explanation,
1155-
existing_tests_source=existing_tests,
1156-
generated_original_test_source=generated_tests_str,
1157-
function_trace_id=self.function_trace_id[:-4] + exp_type
1158-
if self.experiment_id
1159-
else self.function_trace_id,
1160-
coverage_message=coverage_message,
1161-
git_remote=self.args.git_remote,
1162-
)
1163-
if (
1164-
self.args.all
1165-
or env_utils.get_pr_number()
1166-
or self.args.replay_test
1167-
or (self.args.file and not self.args.function)
1168-
):
1169-
self.write_code_and_helpers(
1170-
self.function_to_optimize_source_code,
1171-
original_helper_code,
1172-
self.function_to_optimize.file_path,
1173-
)
1174-
else:
1175-
# Mark optimization success since no PR will be created
1176-
mark_optimization_success(
1177-
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
1178-
)
1091+
self.process_review(
1092+
original_code_baseline,
1093+
best_optimization,
1094+
generated_tests,
1095+
test_functions_to_remove,
1096+
concolic_test_str,
1097+
original_code_combined,
1098+
new_code_combined,
1099+
explanation,
1100+
function_to_all_tests,
1101+
exp_type,
1102+
original_helper_code,
1103+
code_context,
1104+
)
11791105
self.log_successful_optimization(explanation, generated_tests, exp_type)
11801106
return best_optimization
11811107

1108+
def process_review(
1109+
self,
1110+
original_code_baseline: OriginalCodeBaseline,
1111+
best_optimization: BestOptimization,
1112+
generated_tests: GeneratedTestsList,
1113+
test_functions_to_remove: list[str],
1114+
concolic_test_str: str | None,
1115+
original_code_combined: dict[Path, str],
1116+
new_code_combined: dict[Path, str],
1117+
explanation: Explanation,
1118+
function_to_all_tests: dict[str, set[FunctionCalledInTest]],
1119+
exp_type: str,
1120+
original_helper_code: dict[Path, str],
1121+
code_context: CodeOptimizationContext,
1122+
) -> None:
1123+
coverage_message = (
1124+
original_code_baseline.coverage_results.build_message()
1125+
if original_code_baseline.coverage_results
1126+
else "Coverage data not available"
1127+
)
1128+
1129+
generated_tests = remove_functions_from_generated_tests(
1130+
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
1131+
)
1132+
1133+
original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
1134+
optimized_runtime_by_test = (
1135+
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
1136+
)
1137+
1138+
generated_tests = add_runtime_comments_to_generated_tests(
1139+
generated_tests, original_runtime_by_test, optimized_runtime_by_test
1140+
)
1141+
1142+
generated_tests_str = "\n#------------------------------------------------\n".join(
1143+
[test.generated_original_test_source for test in generated_tests.generated_tests]
1144+
)
1145+
if concolic_test_str:
1146+
generated_tests_str += "\n#------------------------------------------------\n" + concolic_test_str
1147+
1148+
existing_tests = existing_tests_source_for(
1149+
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
1150+
function_to_all_tests,
1151+
test_cfg=self.test_cfg,
1152+
original_runtimes_all=original_runtime_by_test,
1153+
optimized_runtimes_all=optimized_runtime_by_test,
1154+
)
1155+
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
1156+
source_code=code_context.read_writable_code,
1157+
dependency_code=code_context.read_only_context_code,
1158+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
1159+
optimized_code=best_optimization.candidate.source_code,
1160+
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
1161+
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
1162+
original_code_runtime=humanize_runtime(original_code_baseline.runtime),
1163+
optimized_code_runtime=humanize_runtime(best_optimization.runtime),
1164+
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
1165+
annotated_tests=generated_tests_str,
1166+
optimization_id=best_optimization.candidate.optimization_id,
1167+
original_explanation=best_optimization.candidate.explanation,
1168+
)
1169+
new_explanation = Explanation(
1170+
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
1171+
winning_behavior_test_results=explanation.winning_behavior_test_results,
1172+
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
1173+
original_runtime_ns=explanation.original_runtime_ns,
1174+
best_runtime_ns=explanation.best_runtime_ns,
1175+
function_name=explanation.function_name,
1176+
file_path=explanation.file_path,
1177+
benchmark_details=explanation.benchmark_details,
1178+
)
1179+
data = {
1180+
"original_code": original_code_combined,
1181+
"new_code": new_code_combined,
1182+
"explanation": new_explanation,
1183+
"existing_tests_source": existing_tests,
1184+
"generated_original_test_source": generated_tests_str,
1185+
"function_trace_id": self.function_trace_id[:-4] + exp_type
1186+
if self.experiment_id
1187+
else self.function_trace_id,
1188+
"coverage_message": coverage_message,
1189+
}
1190+
1191+
if not self.args.no_pr and not self.args.staging_review:
1192+
data["git_remote"] = self.args.git_remote
1193+
check_create_pr(**data)
1194+
elif self.args.staging_review:
1195+
create_staging(**data)
1196+
else:
1197+
# Mark optimization success since no PR will be created
1198+
mark_optimization_success(
1199+
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
1200+
)
1201+
1202+
if ((not self.args.no_pr) or not self.args.staging_review) and (
1203+
self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function)
1204+
):
1205+
self.write_code_and_helpers(
1206+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1207+
)
1208+
11821209
def establish_original_code_baseline(
11831210
self,
11841211
code_context: CodeOptimizationContext,

0 commit comments

Comments
 (0)