Skip to content

Commit 2e46088

Browse files
authored
Merge pull request #480 from codeflash-ai/feat-staging
[Feat] add staging argument
2 parents e806c5d + 1ac52b0 commit 2e46088

File tree

3 files changed

+149
-64
lines changed

3 files changed

+149
-64
lines changed

codeflash/api/cfapi.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_repo_owner_and_name
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
18+
from codeflash.github.PrComment import FileDiffContent, PrComment
1819
from codeflash.version import __version__
1920

2021
if TYPE_CHECKING:
2122
from requests import Response
2223

23-
from codeflash.github.PrComment import FileDiffContent, PrComment
24+
from codeflash.result.explanation import Explanation
25+
2426
from packaging import version
2527

2628
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
@@ -182,6 +184,59 @@ def create_pr(
182184
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)
183185

184186

187+
def create_staging(
188+
original_code: dict[Path, str],
189+
new_code: dict[Path, str],
190+
explanation: Explanation,
191+
existing_tests_source: str,
192+
generated_original_test_source: str,
193+
function_trace_id: str,
194+
coverage_message: str,
195+
) -> Response:
196+
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
197+
198+
:param original_code: A mapping of file paths to original source code.
199+
:param new_code: A mapping of file paths to optimized source code.
200+
:param explanation: An Explanation object with optimization details.
201+
:param existing_tests_source: Existing test code.
202+
:param generated_original_test_source: Generated tests for the original function.
203+
:param function_trace_id: Unique identifier for this optimization trace.
204+
:param coverage_message: Coverage report or summary.
205+
:return: The response object from the backend.
206+
"""
207+
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
208+
209+
build_file_changes = {
210+
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
211+
oldContent=original_code[p], newContent=new_code[p]
212+
)
213+
for p in original_code
214+
}
215+
216+
payload = {
217+
"baseBranch": get_current_branch(),
218+
"diffContents": build_file_changes,
219+
"prCommentFields": PrComment(
220+
optimization_explanation=explanation.explanation_message(),
221+
best_runtime=explanation.best_runtime_ns,
222+
original_runtime=explanation.original_runtime_ns,
223+
function_name=explanation.function_name,
224+
relative_file_path=relative_path,
225+
speedup_x=explanation.speedup_x,
226+
speedup_pct=explanation.speedup_pct,
227+
winning_behavior_test_results=explanation.winning_behavior_test_results,
228+
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
229+
benchmark_details=explanation.benchmark_details,
230+
).to_json(),
231+
"existingTests": existing_tests_source,
232+
"generatedTests": generated_original_test_source,
233+
"traceId": function_trace_id,
234+
"coverage_message": coverage_message,
235+
}
236+
237+
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)
238+
239+
185240
def is_github_app_installed_on_repo(owner: str, repo: str, *, suppress_errors: bool = False) -> bool:
186241
"""Check if the Codeflash GitHub App is installed on the specified repository.
187242

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def parse_args() -> Namespace:
7777
parser.add_argument(
7878
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
7979
)
80+
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
8081
parser.add_argument(
8182
"--verify-setup",
8283
action="store_true",

codeflash/optimization/function_optimizer.py

Lines changed: 91 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from rich.tree import Tree
2020

2121
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
22-
from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success
22+
from codeflash.api.cfapi import add_code_context_hash, create_staging, mark_optimization_success
2323
from codeflash.benchmarking.utils import process_benchmark_data
2424
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
2525
from codeflash.code_utils import env_utils
@@ -1088,70 +1088,99 @@ def find_and_process_best_optimization(
10881088
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
10891089
new_code_combined = new_helper_code.copy()
10901090
new_code_combined[explanation.file_path] = new_code
1091-
if not self.args.no_pr:
1092-
coverage_message = (
1093-
original_code_baseline.coverage_results.build_message()
1094-
if original_code_baseline.coverage_results
1095-
else "Coverage data not available"
1096-
)
1097-
generated_tests = remove_functions_from_generated_tests(
1098-
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
1099-
)
1100-
original_runtime_by_test = (
1101-
original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
1102-
)
1103-
optimized_runtime_by_test = (
1104-
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
1105-
)
1106-
qualified_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
1107-
# Add runtime comments to generated tests before creating the PR
1108-
generated_tests = add_runtime_comments_to_generated_tests(
1109-
generated_tests, original_runtime_by_test, optimized_runtime_by_test
1110-
)
1111-
generated_tests_str = "\n\n".join(
1112-
[test.generated_original_test_source for test in generated_tests.generated_tests]
1113-
)
1114-
existing_tests = existing_tests_source_for(
1115-
qualified_name,
1116-
function_to_all_tests,
1117-
test_cfg=self.test_cfg,
1118-
original_runtimes_all=original_runtime_by_test,
1119-
optimized_runtimes_all=optimized_runtime_by_test,
1120-
)
1121-
if concolic_test_str:
1122-
generated_tests_str += "\n\n" + concolic_test_str
1123-
1124-
check_create_pr(
1125-
original_code=original_code_combined,
1126-
new_code=new_code_combined,
1127-
explanation=explanation,
1128-
existing_tests_source=existing_tests,
1129-
generated_original_test_source=generated_tests_str,
1130-
function_trace_id=self.function_trace_id[:-4] + exp_type
1131-
if self.experiment_id
1132-
else self.function_trace_id,
1133-
coverage_message=coverage_message,
1134-
git_remote=self.args.git_remote,
1135-
)
1136-
if (
1137-
self.args.all
1138-
or env_utils.get_pr_number()
1139-
or self.args.replay_test
1140-
or (self.args.file and not self.args.function)
1141-
):
1142-
self.write_code_and_helpers(
1143-
self.function_to_optimize_source_code,
1144-
original_helper_code,
1145-
self.function_to_optimize.file_path,
1146-
)
1147-
else:
1148-
# Mark optimization success since no PR will be created
1149-
mark_optimization_success(
1150-
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
1151-
)
1091+
self.process_review(
1092+
original_code_baseline,
1093+
best_optimization,
1094+
generated_tests,
1095+
test_functions_to_remove,
1096+
concolic_test_str,
1097+
original_code_combined,
1098+
new_code_combined,
1099+
explanation,
1100+
function_to_all_tests,
1101+
exp_type,
1102+
original_helper_code,
1103+
)
11521104
self.log_successful_optimization(explanation, generated_tests, exp_type)
11531105
return best_optimization
11541106

1107+
def process_review(
1108+
self,
1109+
original_code_baseline: OriginalCodeBaseline,
1110+
best_optimization: BestOptimization,
1111+
generated_tests: GeneratedTestsList,
1112+
test_functions_to_remove: list[str],
1113+
concolic_test_str: str | None,
1114+
original_code_combined: dict[Path, str],
1115+
new_code_combined: dict[Path, str],
1116+
explanation: Explanation,
1117+
function_to_all_tests: dict[str, set[FunctionCalledInTest]],
1118+
exp_type: str,
1119+
original_helper_code: dict[Path, str],
1120+
) -> None:
1121+
coverage_message = (
1122+
original_code_baseline.coverage_results.build_message()
1123+
if original_code_baseline.coverage_results
1124+
else "Coverage data not available"
1125+
)
1126+
1127+
generated_tests = remove_functions_from_generated_tests(
1128+
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
1129+
)
1130+
1131+
original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
1132+
optimized_runtime_by_test = (
1133+
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
1134+
)
1135+
1136+
generated_tests = add_runtime_comments_to_generated_tests(
1137+
generated_tests, original_runtime_by_test, optimized_runtime_by_test
1138+
)
1139+
1140+
generated_tests_str = "\n\n".join(
1141+
[test.generated_original_test_source for test in generated_tests.generated_tests]
1142+
)
1143+
if concolic_test_str:
1144+
generated_tests_str += "\n\n" + concolic_test_str
1145+
1146+
existing_tests = existing_tests_source_for(
1147+
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
1148+
function_to_all_tests,
1149+
test_cfg=self.test_cfg,
1150+
original_runtimes_all=original_runtime_by_test,
1151+
optimized_runtimes_all=optimized_runtime_by_test,
1152+
)
1153+
1154+
data = {
1155+
"original_code": original_code_combined,
1156+
"new_code": new_code_combined,
1157+
"explanation": explanation,
1158+
"existing_tests_source": existing_tests,
1159+
"generated_original_test_source": generated_tests_str,
1160+
"function_trace_id": self.function_trace_id[:-4] + exp_type
1161+
if self.experiment_id
1162+
else self.function_trace_id,
1163+
"coverage_message": coverage_message,
1164+
}
1165+
1166+
if not self.args.no_pr and not self.args.staging_review:
1167+
data["git_remote"] = self.args.git_remote
1168+
check_create_pr(**data)
1169+
elif self.args.staging_review:
1170+
create_staging(**data)
1171+
else:
1172+
# Mark optimization success since no PR will be created
1173+
mark_optimization_success(
1174+
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
1175+
)
1176+
1177+
if ((not self.args.no_pr) or not self.args.staging_review) and (
1178+
self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function)
1179+
):
1180+
self.write_code_and_helpers(
1181+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1182+
)
1183+
11551184
def establish_original_code_baseline(
11561185
self,
11571186
code_context: CodeOptimizationContext,

0 commit comments

Comments
 (0)