Skip to content

Commit 86bb1b5

Browse files
committed
create new staging argument
1 parent 2b5fa6e commit 86bb1b5

File tree

3 files changed

+125
-58
lines changed

3 files changed

+125
-58
lines changed

codeflash/api/cfapi.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_repo_owner_and_name
17+
from codeflash.code_utils.git_utils import get_repo_owner_and_name, git_root_dir, get_current_branch
18+
from codeflash.result.explanation import Explanation
1819
from codeflash.version import __version__
20+
from codeflash.github.PrComment import FileDiffContent, PrComment
1921

2022
if TYPE_CHECKING:
2123
from requests import Response
2224

23-
from codeflash.github.PrComment import FileDiffContent, PrComment
2425
from packaging import version
2526

2627
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
@@ -174,6 +175,55 @@ def create_pr(
174175
}
175176
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)
176177

178+
def create_staging(
179+
original_code:str,
180+
new_code: str,
181+
explanation: Explanation,
182+
existing_tests_source: str,
183+
generated_original_test_source: str,
184+
function_trace_id: str,
185+
coverage_message: str,
186+
) -> Response:
187+
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
188+
189+
:param owner: The owner of the repository.
190+
:param repo: The name of the repository.
191+
:param base_branch: The base branch to target.
192+
:param file_changes: A dictionary of file changes.
193+
:param pr_comment: The pull request comment object, containing the optimization explanation, best runtime, etc.
194+
:param generated_tests: The generated tests.
195+
:return: The response object.
196+
"""
197+
# convert Path objects to strings
198+
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
199+
200+
build_file_changes = {
201+
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
202+
oldContent=original_code[p], newContent=new_code[p]
203+
)
204+
for p in original_code
205+
}
206+
payload = {
207+
"baseBranch": get_current_branch(),
208+
"diffContents": build_file_changes,
209+
"prCommentFields": PrComment(
210+
optimization_explanation=explanation.explanation_message(),
211+
best_runtime=explanation.best_runtime_ns,
212+
original_runtime=explanation.original_runtime_ns,
213+
function_name=explanation.function_name,
214+
relative_file_path=relative_path,
215+
speedup_x=explanation.speedup_x,
216+
speedup_pct=explanation.speedup_pct,
217+
winning_behavioral_test_results=explanation.winning_behavioral_test_results,
218+
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
219+
benchmark_details=explanation.benchmark_details,
220+
).to_json(),
221+
"existingTests": existing_tests_source,
222+
"generatedTests": generated_original_test_source,
223+
"traceId": function_trace_id,
224+
"coverage_message": coverage_message,
225+
}
226+
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)
177227

178228
def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
179229
"""Check if the Codeflash GitHub App is installed on the specified repository.

codeflash/cli_cmds/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def parse_args() -> Namespace:
4747
parser.add_argument(
4848
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
4949
)
50+
parser.add_argument(
51+
"--staging-review",
52+
action="store_true",
53+
help="Upload optimizations to staging for review"
54+
)
5055
parser.add_argument(
5156
"--verify-setup",
5257
action="store_true",

codeflash/optimization/function_optimizer.py

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from rich.tree import Tree
2020

2121
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
22-
from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success
22+
from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success, create_staging
2323
from codeflash.benchmarking.utils import process_benchmark_data
2424
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
2525
from codeflash.code_utils import env_utils
@@ -997,64 +997,76 @@ def find_and_process_best_optimization(
997997
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
998998
new_code_combined = new_helper_code.copy()
999999
new_code_combined[explanation.file_path] = new_code
1000-
if not self.args.no_pr:
1001-
coverage_message = (
1002-
original_code_baseline.coverage_results.build_message()
1003-
if original_code_baseline.coverage_results
1004-
else "Coverage data not available"
1005-
)
1006-
generated_tests = remove_functions_from_generated_tests(
1007-
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
1008-
)
1009-
original_runtime_by_test = (
1010-
original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
1011-
)
1012-
optimized_runtime_by_test = (
1013-
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
1014-
)
1015-
# Add runtime comments to generated tests before creating the PR
1016-
generated_tests = add_runtime_comments_to_generated_tests(
1017-
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
1018-
)
1019-
generated_tests_str = "\n\n".join(
1020-
[test.generated_original_test_source for test in generated_tests.generated_tests]
1021-
)
1022-
existing_tests = existing_tests_source_for(
1023-
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
1024-
function_to_all_tests,
1025-
test_cfg=self.test_cfg,
1026-
original_runtimes_all=original_runtime_by_test,
1027-
optimized_runtimes_all=optimized_runtime_by_test,
1028-
)
1029-
if concolic_test_str:
1030-
generated_tests_str += "\n\n" + concolic_test_str
1031-
1032-
check_create_pr(
1033-
original_code=original_code_combined,
1034-
new_code=new_code_combined,
1035-
explanation=explanation,
1036-
existing_tests_source=existing_tests,
1037-
generated_original_test_source=generated_tests_str,
1038-
function_trace_id=self.function_trace_id[:-4] + exp_type
1039-
if self.experiment_id
1040-
else self.function_trace_id,
1041-
coverage_message=coverage_message,
1042-
git_remote=self.args.git_remote,
1043-
)
1044-
if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
1045-
self.write_code_and_helpers(
1046-
self.function_to_optimize_source_code,
1047-
original_helper_code,
1048-
self.function_to_optimize.file_path,
1049-
)
1050-
else:
1051-
# Mark optimization success since no PR will be created
1052-
mark_optimization_success(
1053-
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
1054-
)
1000+
self.process_review(original_code_baseline, best_optimization, generated_tests,
1001+
test_functions_to_remove, concolic_test_str, original_code_combined, new_code_combined,
1002+
explanation, function_to_all_tests, exp_type, original_helper_code)
10551003
self.log_successful_optimization(explanation, generated_tests, exp_type)
10561004
return best_optimization
10571005

1006+
def process_review(
1007+
self, original_code_baseline, best_optimization, generated_tests,
1008+
test_functions_to_remove, concolic_test_str, original_code_combined,
1009+
new_code_combined, explanation, function_to_all_tests, exp_type, original_helper_code
1010+
):
1011+
coverage_message = (
1012+
original_code_baseline.coverage_results.build_message()
1013+
if original_code_baseline.coverage_results else "Coverage data not available"
1014+
)
1015+
1016+
generated_tests = remove_functions_from_generated_tests(
1017+
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
1018+
)
1019+
1020+
original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
1021+
optimized_runtime_by_test = best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
1022+
1023+
generated_tests = add_runtime_comments_to_generated_tests(
1024+
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
1025+
)
1026+
1027+
generated_tests_str = "\n\n".join(
1028+
[test.generated_original_test_source for test in generated_tests.generated_tests]
1029+
)
1030+
if concolic_test_str:
1031+
generated_tests_str += "\n\n" + concolic_test_str
1032+
1033+
existing_tests = existing_tests_source_for(
1034+
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
1035+
function_to_all_tests,
1036+
test_cfg=self.test_cfg,
1037+
original_runtimes_all=original_runtime_by_test,
1038+
optimized_runtimes_all=optimized_runtime_by_test,
1039+
)
1040+
1041+
data = dict(
1042+
original_code=original_code_combined,
1043+
new_code=new_code_combined,
1044+
explanation=explanation,
1045+
existing_tests_source=existing_tests,
1046+
generated_original_test_source=generated_tests_str,
1047+
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
1048+
coverage_message=coverage_message,
1049+
)
1050+
1051+
if not self.args.no_pr and not self.args.staging_review:
1052+
data["git_remote"] = self.args.git_remote
1053+
check_create_pr(**data)
1054+
elif self.args.staging_review:
1055+
create_staging(**data)
1056+
else:
1057+
# Mark optimization success since no PR will be created
1058+
mark_optimization_success(
1059+
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
1060+
)
1061+
1062+
if (not self.args.no_pr) or not self.args.staging_review:
1063+
if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function):
1064+
self.write_code_and_helpers(
1065+
self.function_to_optimize_source_code,
1066+
original_helper_code,
1067+
self.function_to_optimize.file_path,
1068+
)
1069+
10581070
def establish_original_code_baseline(
10591071
self,
10601072
code_context: CodeOptimizationContext,

0 commit comments

Comments
 (0)