Skip to content

Commit 4bd757b

Browse files
authored
Merge branch 'main' into chore/handle-json-and-string-errors-from-cfapi
2 parents abe3017 + 55a48eb commit 4bd757b

File tree

18 files changed

+1578
-283
lines changed

18 files changed

+1578
-283
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,4 @@ jobs:
3232
run: uvx poetry install --with dev
3333

3434
- name: Unit tests
35-
run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip"
36-
37-
- name: Upload coverage reports to Codecov
38-
uses: codecov/codecov-action@v5
39-
if: matrix.python-version == '3.12.1'
40-
with:
41-
token: ${{ secrets.CODECOV_TOKEN }}
35+
run: uvx poetry run pytest tests/ --benchmark-skip -m "not ci_skip"

codeflash/api/aiservice.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from codeflash.cli_cmds.console import console, logger
1313
from codeflash.code_utils.env_utils import get_codeflash_api_key
14+
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1415
from codeflash.models.models import OptimizedCandidate
1516
from codeflash.telemetry.posthog_cf import ph
1617
from codeflash.version import __version__ as codeflash_version
@@ -97,6 +98,12 @@ def optimize_python_code( # noqa: D417
9798
9899
"""
99100
start_time = time.perf_counter()
101+
try:
102+
git_repo_owner, git_repo_name = get_repo_owner_and_name()
103+
except Exception as e:
104+
logger.warning(f"Could not determine repo owner and name: {e}")
105+
git_repo_owner, git_repo_name = None, None
106+
100107
payload = {
101108
"source_code": source_code,
102109
"dependency_code": dependency_code,
@@ -105,6 +112,9 @@ def optimize_python_code( # noqa: D417
105112
"python_version": platform.python_version(),
106113
"experiment_metadata": experiment_metadata,
107114
"codeflash_version": codeflash_version,
115+
"current_username": get_last_commit_author_if_pr_exists(None),
116+
"repo_owner": git_repo_owner,
117+
"repo_name": git_repo_name,
108118
}
109119

110120
logger.info("Generating optimized candidates…")

codeflash/api/cfapi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,14 @@ def add_code_context_hash(code_context_hash: str) -> None:
239239
"POST",
240240
{"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
241241
)
242+
243+
244+
def mark_optimization_success(trace_id: str, *, is_optimization_found: bool) -> Response:
245+
"""Mark an optimization event as success or not.
246+
247+
:param trace_id: The unique identifier for the optimization event.
248+
:param is_optimization_found: Boolean indicating whether the optimization was found.
249+
:return: The response object from the API.
250+
"""
251+
payload = {"trace_id": trace_id, "is_optimization_found": is_optimization_found}
252+
return make_cfapi_request(endpoint="/mark-as-success", method="POST", payload=payload)

codeflash/cli_cmds/cli.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
99
from codeflash.cli_cmds.console import logger
1010
from codeflash.code_utils import env_utils
11+
from codeflash.code_utils.code_utils import exit_with_message
1112
from codeflash.code_utils.config_parser import parse_config_file
1213
from codeflash.version import __version__ as version
1314

@@ -42,7 +43,7 @@ def parse_args() -> Namespace:
4243
)
4344
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
4445
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
45-
parser.add_argument("--replay-test", type=str, help="Path to replay test to optimize functions from")
46+
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
4647
parser.add_argument(
4748
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
4849
)
@@ -83,25 +84,22 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
8384
sys.exit()
8485
if not check_running_in_git_repo(module_root=args.module_root):
8586
if not confirm_proceeding_with_no_git_repo():
86-
logger.critical("No git repository detected and user aborted run. Exiting...")
87-
sys.exit(1)
87+
exit_with_message("No git repository detected and user aborted run. Exiting...", error_on_exit=True)
8888
args.no_pr = True
8989
if args.function and not args.file:
90-
logger.error("If you specify a --function, you must specify the --file it is in")
91-
sys.exit(1)
90+
exit_with_message("If you specify a --function, you must specify the --file it is in", error_on_exit=True)
9291
if args.file:
9392
if not Path(args.file).exists():
94-
logger.error(f"File {args.file} does not exist")
95-
sys.exit(1)
93+
exit_with_message(f"File {args.file} does not exist", error_on_exit=True)
9694
args.file = Path(args.file).resolve()
9795
if not args.no_pr:
9896
owner, repo = get_repo_owner_and_name()
9997
require_github_app_or_exit(owner, repo)
10098
if args.replay_test:
101-
if not Path(args.replay_test).is_file():
102-
logger.error(f"Replay test file {args.replay_test} does not exist")
103-
sys.exit(1)
104-
args.replay_test = Path(args.replay_test).resolve()
99+
for test_path in args.replay_test:
100+
if not Path(test_path).is_file():
101+
exit_with_message(f"Replay test file {test_path} does not exist", error_on_exit=True)
102+
args.replay_test = [Path(replay_test).resolve() for replay_test in args.replay_test]
105103

106104
return args
107105

@@ -110,8 +108,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
110108
try:
111109
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
112110
except ValueError as e:
113-
logger.error(e)
114-
sys.exit(1)
111+
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
115112
supported_keys = [
116113
"module_root",
117114
"tests_root",
@@ -206,8 +203,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
206203
)
207204
apologize_and_exit()
208205
if not args.no_pr and not check_and_push_branch(git_repo):
209-
logger.critical("❌ Branch is not pushed. Exiting...")
210-
sys.exit(1)
206+
exit_with_message("Branch is not pushed...", error_on_exit=True)
211207
owner, repo = get_repo_owner_and_name(git_repo)
212208
if not args.no_pr:
213209
require_github_app_or_exit(owner, repo)

codeflash/code_utils/code_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import re
66
import shutil
77
import site
8+
import sys
89
from contextlib import contextmanager
910
from functools import lru_cache
1011
from pathlib import Path
1112
from tempfile import TemporaryDirectory
1213

1314
import tomlkit
1415

15-
from codeflash.cli_cmds.console import logger
16+
from codeflash.cli_cmds.console import logger, paneled_text
1617
from codeflash.code_utils.config_parser import find_pyproject_toml
1718

1819
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
@@ -33,7 +34,7 @@ def custom_addopts() -> None:
3334
# Backup original addopts
3435
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
3536
# nothing to do if no addopts present
36-
if original_addopts != "":
37+
if original_addopts != "" and isinstance(original_addopts, list):
3738
original_addopts = [x.strip() for x in original_addopts]
3839
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
3940
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
@@ -213,3 +214,9 @@ def cleanup_paths(paths: list[Path]) -> None:
213214
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
214215
for path, file_content in path_to_content_map.items():
215216
path.write_text(file_content, encoding="utf8")
217+
218+
219+
def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
220+
paneled_text(message, panel_args={"style": "red"})
221+
222+
sys.exit(1 if error_on_exit else 0)

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
import os
4-
import sys
54
import tempfile
65
from functools import lru_cache
76
from pathlib import Path
87
from typing import Optional
98

109
from codeflash.cli_cmds.console import logger
10+
from codeflash.code_utils.code_utils import exit_with_message
1111
from codeflash.code_utils.formatter import format_code
1212
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
1313

@@ -24,11 +24,11 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
2424
try:
2525
format_code(formatter_cmds, tmp_file, print_status=False)
2626
except Exception:
27-
print(
28-
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
27+
exit_with_message(
28+
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
29+
error_on_exit=True,
2930
)
30-
if exit_on_failure:
31-
sys.exit(1)
31+
3232
return return_code
3333

3434

codeflash/code_utils/formatter.py

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,78 @@
11
from __future__ import annotations
22

3+
import difflib
34
import os
5+
import re
46
import shlex
7+
import shutil
58
import subprocess
6-
from typing import TYPE_CHECKING
9+
import tempfile
10+
from pathlib import Path
11+
from typing import Optional, Union
712

813
import isort
914

1015
from codeflash.cli_cmds.console import console, logger
1116

12-
if TYPE_CHECKING:
13-
from pathlib import Path
1417

18+
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
19+
line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")
1520

16-
def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
21+
def split_lines(text: str) -> list[str]:
22+
lines = [match[0] for match in line_pattern.finditer(text)]
23+
if lines and lines[-1] == "":
24+
lines.pop()
25+
return lines
26+
27+
original_lines = split_lines(original)
28+
modified_lines = split_lines(modified)
29+
30+
diff_output = []
31+
for line in difflib.unified_diff(original_lines, modified_lines, fromfile=from_file, tofile=to_file, n=5):
32+
if line.endswith("\n"):
33+
diff_output.append(line)
34+
else:
35+
diff_output.append(line + "\n")
36+
diff_output.append("\\ No newline at end of file\n")
37+
38+
return "".join(diff_output)
39+
40+
41+
def apply_formatter_cmds(
42+
cmds: list[str],
43+
path: Path,
44+
test_dir_str: Optional[str],
45+
print_status: bool, # noqa
46+
) -> tuple[Path, str]:
1747
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
18-
formatter_name = formatter_cmds[0].lower()
48+
formatter_name = cmds[0].lower()
49+
should_make_copy = False
50+
file_path = path
51+
52+
if test_dir_str:
53+
should_make_copy = True
54+
file_path = Path(test_dir_str) / "temp.py"
55+
56+
if not cmds or formatter_name == "disabled":
57+
return path, path.read_text(encoding="utf8")
58+
1959
if not path.exists():
20-
msg = f"File {path} does not exist. Cannot format the file."
60+
msg = f"File {path} does not exist. Cannot apply formatter commands."
2161
raise FileNotFoundError(msg)
22-
if formatter_name == "disabled":
23-
return path.read_text(encoding="utf8")
62+
63+
if should_make_copy:
64+
shutil.copy2(path, file_path)
65+
2466
file_token = "$file" # noqa: S105
25-
for command in formatter_cmds:
67+
68+
for command in cmds:
2669
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
27-
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
70+
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
2871
try:
2972
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
3073
if result.returncode == 0:
3174
if print_status:
32-
console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}")
75+
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
3376
else:
3477
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
3578
except FileNotFoundError as e:
@@ -44,7 +87,60 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
4487

4588
raise e from None
4689

47-
return path.read_text(encoding="utf8")
90+
return file_path, file_path.read_text(encoding="utf8")
91+
92+
93+
def get_diff_lines_count(diff_output: str) -> int:
94+
lines = diff_output.split("\n")
95+
96+
def is_diff_line(line: str) -> bool:
97+
return line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
98+
99+
diff_lines = [line for line in lines if is_diff_line(line)]
100+
return len(diff_lines)
101+
102+
103+
def format_code(
104+
formatter_cmds: list[str],
105+
path: Union[str, Path],
106+
optimized_function: str = "",
107+
check_diff: bool = False, # noqa
108+
print_status: bool = True, # noqa
109+
) -> str:
110+
with tempfile.TemporaryDirectory() as test_dir_str:
111+
if isinstance(path, str):
112+
path = Path(path)
113+
114+
original_code = path.read_text(encoding="utf8")
115+
original_code_lines = len(original_code.split("\n"))
116+
117+
if check_diff and original_code_lines > 50:
118+
# we dont' count the formatting diff for the optimized function as it should be well-formatted
119+
original_code_without_opfunc = original_code.replace(optimized_function, "")
120+
121+
original_temp = Path(test_dir_str) / "original_temp.py"
122+
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
123+
124+
formatted_temp, formatted_code = apply_formatter_cmds(
125+
formatter_cmds, original_temp, test_dir_str, print_status=False
126+
)
127+
128+
diff_output = generate_unified_diff(
129+
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
130+
)
131+
diff_lines_count = get_diff_lines_count(diff_output)
132+
133+
max_diff_lines = min(int(original_code_lines * 0.3), 50)
134+
135+
if diff_lines_count > max_diff_lines and max_diff_lines != -1:
136+
logger.debug(
137+
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
138+
)
139+
return original_code
140+
# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
141+
_, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
142+
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
143+
return formatted_code
48144

49145

50146
def sort_imports(code: str) -> str:

codeflash/code_utils/git_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import shutil
45
import subprocess
56
import sys
@@ -176,3 +177,20 @@ def remove_git_worktrees(worktree_root: Path | None, worktrees: list[Path]) -> N
176177
logger.warning(f"Error removing worktrees: {e}")
177178
if worktree_root:
178179
shutil.rmtree(worktree_root)
180+
181+
182+
def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
183+
"""Return the author's name of the last commit in the current branch if PR_NUMBER is set.
184+
185+
Otherwise, return None.
186+
"""
187+
if "PR_NUMBER" not in os.environ:
188+
return None
189+
try:
190+
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
191+
last_commit = repository.head.commit
192+
except Exception:
193+
logger.exception("Failed to get last commit author.")
194+
return None
195+
else:
196+
return last_commit.author.name

0 commit comments

Comments
 (0)