Skip to content

Commit b6a68ed

Browse files
authored
Merge branch 'main' into replay-test-save-dir
2 parents cedf3f1 + 55a48eb commit b6a68ed

File tree

12 files changed

+315
-161
lines changed

12 files changed

+315
-161
lines changed

codeflash/api/aiservice.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from codeflash.cli_cmds.console import console, logger
1313
from codeflash.code_utils.env_utils import get_codeflash_api_key
14+
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1415
from codeflash.models.models import OptimizedCandidate
1516
from codeflash.telemetry.posthog_cf import ph
1617
from codeflash.version import __version__ as codeflash_version
@@ -97,6 +98,12 @@ def optimize_python_code( # noqa: D417
9798
9899
"""
99100
start_time = time.perf_counter()
101+
try:
102+
git_repo_owner, git_repo_name = get_repo_owner_and_name()
103+
except Exception as e:
104+
logger.warning(f"Could not determine repo owner and name: {e}")
105+
git_repo_owner, git_repo_name = None, None
106+
100107
payload = {
101108
"source_code": source_code,
102109
"dependency_code": dependency_code,
@@ -105,6 +112,9 @@ def optimize_python_code( # noqa: D417
105112
"python_version": platform.python_version(),
106113
"experiment_metadata": experiment_metadata,
107114
"codeflash_version": codeflash_version,
115+
"current_username": get_last_commit_author_if_pr_exists(None),
116+
"repo_owner": git_repo_owner,
117+
"repo_name": git_repo_name,
108118
}
109119

110120
logger.info("Generating optimized candidates…")

codeflash/api/cfapi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,14 @@ def add_code_context_hash(code_context_hash: str) -> None:
224224
"POST",
225225
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
226226
)
227+
228+
229+
def mark_optimization_success(trace_id: str, *, is_optimization_found: bool) -> Response:
230+
"""Mark an optimization event as success or not.
231+
232+
:param trace_id: The unique identifier for the optimization event.
233+
:param is_optimization_found: Boolean indicating whether the optimization was found.
234+
:return: The response object from the API.
235+
"""
236+
payload = {"trace_id": trace_id, "is_optimization_found": is_optimization_found}
237+
return make_cfapi_request(endpoint="/mark-as-success", method="POST", payload=payload)

codeflash/cli_cmds/cli.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
99
from codeflash.cli_cmds.console import logger
1010
from codeflash.code_utils import env_utils
11+
from codeflash.code_utils.code_utils import exit_with_message
1112
from codeflash.code_utils.config_parser import parse_config_file
1213
from codeflash.version import __version__ as version
1314

@@ -42,7 +43,7 @@ def parse_args() -> Namespace:
4243
)
4344
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
4445
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
45-
parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
46+
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
4647
parser.add_argument(
4748
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
4849
)
@@ -83,25 +84,22 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
8384
sys.exit()
8485
if not check_running_in_git_repo(module_root=args.module_root):
8586
if not confirm_proceeding_with_no_git_repo():
86-
logger.critical("No git repository detected and user aborted run. Exiting...")
87-
sys.exit(1)
87+
exit_with_message("No git repository detected and user aborted run. Exiting...", error_on_exit=True)
8888
args.no_pr = True
8989
if args.function and not args.file:
90-
logger.error("If you specify a --function, you must specify the --file it is in")
91-
sys.exit(1)
90+
exit_with_message("If you specify a --function, you must specify the --file it is in", error_on_exit=True)
9291
if args.file:
9392
if not Path(args.file).exists():
94-
logger.error(f"File {args.file} does not exist")
95-
sys.exit(1)
93+
exit_with_message(f"File {args.file} does not exist", error_on_exit=True)
9694
args.file = Path(args.file).resolve()
9795
if not args.no_pr:
9896
owner, repo = get_repo_owner_and_name()
9997
require_github_app_or_exit(owner, repo)
10098
if args.replay_test:
101-
if not Path(args.replay_test).is_file():
102-
logger.error(f"Replay test file {args.replay_test} does not exist")
103-
sys.exit(1)
104-
args.replay_test = Path(args.replay_test).resolve()
99+
for test_path in args.replay_test:
100+
if not Path(test_path).is_file():
101+
exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True)
102+
args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]
105103

106104
return args
107105

@@ -110,8 +108,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
110108
try:
111109
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
112110
except ValueError as e:
113-
logger.error(e)
114-
sys.exit(1)
111+
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
115112
supported_keys = [
116113
"module_root",
117114
"tests_root",
@@ -203,8 +200,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
203200
)
204201
apologize_and_exit()
205202
if not args.no_pr and not check_and_push_branch(git_repo):
206-
logger.critical("❌ Branch is not pushed. Exiting...")
207-
sys.exit(1)
203+
exit_with_message("Branch is not pushed...", error_on_exit=True)
208204
owner, repo = get_repo_owner_and_name(git_repo)
209205
if not args.no_pr:
210206
require_github_app_or_exit(owner, repo)

codeflash/code_utils/code_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import re
66
import shutil
77
import site
8+
import sys
89
from contextlib import contextmanager
910
from functools import lru_cache
1011
from pathlib import Path
1112
from tempfile import TemporaryDirectory
1213

1314
import tomlkit
1415

15-
from codeflash.cli_cmds.console import logger
16+
from codeflash.cli_cmds.console import logger, paneled_text
1617
from codeflash.code_utils.config_parser import find_pyproject_toml
1718

1819
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
@@ -33,7 +34,7 @@ def custom_addopts() -> None:
3334
# Backup original addopts
3435
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
3536
# nothing to do if no addopts present
36-
if original_addopts != "":
37+
if original_addopts != "" and isinstance(original_addopts, list):
3738
original_addopts = [x.strip() for x in original_addopts]
3839
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
3940
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
@@ -213,3 +214,9 @@ def cleanup_paths(paths: list[Path]) -> None:
213214
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
214215
for path, file_content in path_to_content_map.items():
215216
path.write_text(file_content, encoding="utf8")
217+
218+
219+
def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
220+
paneled_text(message, panel_args={"style": "red"})
221+
222+
sys.exit(1 if error_on_exit else 0)

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
import os
4-
import sys
54
import tempfile
65
from functools import lru_cache
76
from pathlib import Path
87
from typing import Optional
98

109
from codeflash.cli_cmds.console import logger
10+
from codeflash.code_utils.code_utils import exit_with_message
1111
from codeflash.code_utils.formatter import format_code
1212
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
1313

@@ -24,11 +24,11 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
2424
try:
2525
format_code(formatter_cmds, tmp_file, print_status=False)
2626
except Exception:
27-
print(
28-
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
27+
exit_with_message(
28+
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
29+
error_on_exit=True,
2930
)
30-
if exit_on_failure:
31-
sys.exit(1)
31+
3232
return return_code
3333

3434

codeflash/code_utils/git_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import shutil
45
import subprocess
56
import sys
@@ -176,3 +177,20 @@ def remove_git_worktrees(worktree_root: Path | None, worktrees: list[Path]) -> N
176177
logger.warning(f"Error removing worktrees: {e}")
177178
if worktree_root:
178179
shutil.rmtree(worktree_root)
180+
181+
182+
def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
183+
"""Return the author's name of the last commit in the current branch if PR_NUMBER is set.
184+
185+
Otherwise, return None.
186+
"""
187+
if "PR_NUMBER" not in os.environ:
188+
return None
189+
try:
190+
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
191+
last_commit = repository.head.commit
192+
except Exception:
193+
logger.exception("Failed to get last commit author.")
194+
return None
195+
else:
196+
return last_commit.author.name

0 commit comments

Comments
 (0)