Skip to content

Commit 20567de

Browse files
authored
Merge branch 'main' into saga4/coverage_undefined_error
2 parents a941272 + e1d8fe0 commit 20567de

24 files changed

+365
-128
lines changed

codeflash/api/cfapi.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
179179
if pr_number is None:
180180
return {}
181181

182-
not_found = 404
183-
internal_server_error = 500
184-
185182
owner, repo = get_repo_owner_and_name()
186183
information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
187184
try:

codeflash/benchmarking/replay_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ def create_trace_replay_test_code(
115115
if function_name == "__init__":
116116
ret = {class_name_alias}(*args[1:], **kwargs)
117117
else:
118-
instance = args[0] # self
119-
ret = instance{method_name}(*args[1:], **kwargs)
118+
ret = {class_name_alias}{method_name}(*args, **kwargs)
120119
"""
121120
)
122121

codeflash/cli_cmds/cli.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,12 @@
33
from argparse import SUPPRESS, ArgumentParser, Namespace
44
from pathlib import Path
55

6-
import git
7-
86
from codeflash.cli_cmds import logging_config
97
from codeflash.cli_cmds.cli_common import apologize_and_exit
108
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
119
from codeflash.cli_cmds.console import logger
1210
from codeflash.code_utils import env_utils
1311
from codeflash.code_utils.config_parser import parse_config_file
14-
from codeflash.code_utils.git_utils import (
15-
check_and_push_branch,
16-
check_running_in_git_repo,
17-
confirm_proceeding_with_no_git_repo,
18-
get_repo_owner_and_name,
19-
)
20-
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
2112
from codeflash.version import __version__ as version
2213

2314

@@ -75,6 +66,13 @@ def parse_args() -> Namespace:
7566

7667

7768
def process_and_validate_cmd_args(args: Namespace) -> Namespace:
69+
from codeflash.code_utils.git_utils import (
70+
check_running_in_git_repo,
71+
confirm_proceeding_with_no_git_repo,
72+
get_repo_owner_and_name,
73+
)
74+
from codeflash.code_utils.github_utils import require_github_app_or_exit
75+
7876
is_init: bool = args.command.startswith("init") if args.command else False
7977
if args.verbose:
8078
logging_config.set_level(logging.DEBUG, echo_setting=not is_init)
@@ -144,21 +142,26 @@ def process_pyproject_config(args: Namespace) -> Namespace:
144142
assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), (
145143
f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}"
146144
)
147-
if env_utils.get_pr_number() is not None:
148-
assert env_utils.ensure_codeflash_api_key(), (
149-
"Codeflash API key not found. When running in a Github Actions Context, provide the "
150-
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
151-
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
152-
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
153-
f"Here's a direct link: {get_github_secrets_page_url()}\n"
154-
"Exiting..."
155-
)
145+
if env_utils.get_pr_number() is not None:
146+
import git
147+
148+
from codeflash.code_utils.git_utils import get_repo_owner_and_name
149+
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
150+
151+
assert env_utils.ensure_codeflash_api_key(), (
152+
"Codeflash API key not found. When running in a Github Actions Context, provide the "
153+
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
154+
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
155+
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
156+
f"Here's a direct link: {get_github_secrets_page_url()}\n"
157+
"Exiting..."
158+
)
156159

157-
repo = git.Repo(search_parent_directories=True)
160+
repo = git.Repo(search_parent_directories=True)
158161

159-
owner, repo_name = get_repo_owner_and_name(repo)
162+
owner, repo_name = get_repo_owner_and_name(repo)
160163

161-
require_github_app_or_exit(owner, repo_name)
164+
require_github_app_or_exit(owner, repo_name)
162165

163166
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
164167
normalized_ignore_paths = []
@@ -187,6 +190,11 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
187190

188191
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
189192
if hasattr(args, "all"):
193+
import git
194+
195+
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
196+
from codeflash.code_utils.github_utils import require_github_app_or_exit
197+
190198
# Ensure that the user can actually open PRs on the repo.
191199
try:
192200
git_repo = git.Repo(search_parent_directories=True)

codeflash/cli_cmds/cmd_init.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,8 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
514514
from importlib.resources import files
515515

516516
benchmark_mode = False
517-
if "benchmarks_root" in config:
517+
benchmarks_root = config.get("benchmarks_root", "").strip()
518+
if benchmarks_root and benchmarks_root != "":
518519
benchmark_mode = inquirer_wrapper(
519520
inquirer.confirm,
520521
message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
@@ -720,7 +721,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
720721
)
721722
elif formatter == "don't use a formatter":
722723
formatter_cmds.append("disabled")
723-
check_formatter_installed(formatter_cmds)
724+
check_formatter_installed(formatter_cmds, exit_on_failure=False)
724725
codeflash_section["formatter-cmds"] = formatter_cmds
725726
# Add the 'codeflash' section, ensuring 'tool' section exists
726727
tool_section = pyproject_data.get("tool", tomlkit.table())

codeflash/code_utils/code_extractor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import TYPE_CHECKING, Optional
66

77
import libcst as cst
8-
import libcst.matchers as m
98
from libcst.codemod import CodemodContext
109
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
1110
from libcst.helpers import calculate_module_and_package
@@ -248,6 +247,8 @@ class FutureAliasedImportTransformer(cst.CSTTransformer):
248247
def leave_ImportFrom(
249248
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
250249
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
250+
import libcst.matchers as m
251+
251252
if (
252253
(updated_node_module := updated_node.module)
253254
and updated_node_module.value == "__future__"

codeflash/code_utils/code_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,87 @@
22

33
import ast
44
import os
5+
import re
56
import shutil
67
import site
8+
from contextlib import contextmanager
79
from functools import lru_cache
810
from pathlib import Path
911
from tempfile import TemporaryDirectory
1012

13+
import tomlkit
14+
1115
from codeflash.cli_cmds.console import logger
16+
from codeflash.code_utils.config_parser import find_pyproject_toml
17+
18+
19+
@contextmanager
20+
def custom_addopts() -> None:
21+
pyproject_file = find_pyproject_toml()
22+
original_content = None
23+
non_blacklist_plugin_args = ""
24+
25+
try:
26+
# Read original file
27+
if pyproject_file.exists():
28+
with Path.open(pyproject_file, encoding="utf-8") as f:
29+
original_content = f.read()
30+
data = tomlkit.parse(original_content)
31+
# Backup original addopts
32+
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
33+
# nothing to do if no addopts present
34+
if original_addopts != "":
35+
original_addopts = [x.strip() for x in original_addopts]
36+
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
37+
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
38+
if non_blacklist_plugin_args != original_addopts:
39+
data["tool"]["pytest"]["ini_options"]["addopts"] = non_blacklist_plugin_args
40+
# Write modified file
41+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
42+
f.write(tomlkit.dumps(data))
43+
44+
yield
45+
46+
finally:
47+
# Restore original file
48+
if (
49+
original_content
50+
and pyproject_file.exists()
51+
and tuple(original_addopts) not in {(), tuple(non_blacklist_plugin_args)}
52+
):
53+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
54+
f.write(original_content)
55+
56+
57+
@contextmanager
58+
def add_addopts_to_pyproject() -> None:
59+
pyproject_file = find_pyproject_toml()
60+
original_content = None
61+
try:
62+
# Read original file
63+
if pyproject_file.exists():
64+
with Path.open(pyproject_file, encoding="utf-8") as f:
65+
original_content = f.read()
66+
data = tomlkit.parse(original_content)
67+
data["tool"]["pytest"] = {}
68+
data["tool"]["pytest"]["ini_options"] = {}
69+
data["tool"]["pytest"]["ini_options"]["addopts"] = [
70+
"-n=auto",
71+
"-n",
72+
"1",
73+
"-n 1",
74+
"-n 1",
75+
"-n auto",
76+
]
77+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
78+
f.write(tomlkit.dumps(data))
79+
80+
yield
81+
82+
finally:
83+
# Restore original file
84+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
85+
f.write(original_content)
1286

1387

1488
def encoded_tokens_len(s: str) -> int:

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
N_TESTS_TO_GENERATE = 2
99
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
1010
COVERAGE_THRESHOLD = 60.0
11+
MIN_TESTCASE_PASSED_THRESHOLD = 6

codeflash/code_utils/env_utils.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
from __future__ import annotations
22

33
import os
4-
import shlex
5-
import subprocess
4+
import sys
65
import tempfile
76
from functools import lru_cache
87
from pathlib import Path
98
from typing import Optional
109

1110
from codeflash.cli_cmds.console import logger
11+
from codeflash.code_utils.formatter import format_code
1212
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
1313

1414

15-
class FormatterNotFoundError(Exception):
16-
"""Exception raised when a formatter is not found."""
17-
18-
def __init__(self, formatter_cmd: str) -> None:
19-
super().__init__(f"Formatter command not found: {formatter_cmd}")
20-
21-
22-
def check_formatter_installed(formatter_cmds: list[str]) -> bool:
15+
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
2316
return_code = True
2417
if formatter_cmds[0] == "disabled":
2518
return return_code
@@ -28,22 +21,14 @@ def check_formatter_installed(formatter_cmds: list[str]) -> bool:
2821
f.write(tmp_code)
2922
f.flush()
3023
tmp_file = Path(f.name)
31-
file_token = "$file" # noqa: S105
32-
for command in set(formatter_cmds):
33-
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
34-
formatter_cmd_list = [tmp_file.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
35-
try:
36-
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
37-
except (FileNotFoundError, NotADirectoryError):
38-
return_code = False
39-
break
40-
if result.returncode:
41-
return_code = False
42-
break
43-
tmp_file.unlink(missing_ok=True)
44-
if not return_code:
45-
msg = f"Error running formatter command: {command}"
46-
raise FormatterNotFoundError(msg)
24+
try:
25+
format_code(formatter_cmds, tmp_file)
26+
except Exception:
27+
print(
28+
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
29+
)
30+
if exit_on_failure:
31+
sys.exit(1)
4732
return return_code
4833

4934

codeflash/code_utils/formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def format_code(formatter_cmds: list[str], path: Path) -> str:
2222
if formatter_name == "disabled":
2323
return path.read_text(encoding="utf8")
2424
file_token = "$file" # noqa: S105
25-
for command in set(formatter_cmds):
25+
for command in formatter_cmds:
2626
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
2727
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
2828
try:

codeflash/context/code_context_extractor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from collections import defaultdict
55
from itertools import chain
66
from pathlib import Path # noqa: TC003
7+
from typing import TYPE_CHECKING
78

8-
import jedi
99
import libcst as cst
10-
from jedi.api.classes import Name # noqa: TC002
1110
from libcst import CSTNode # noqa: TC002
1211

1312
from codeflash.cli_cmds.console import logger
@@ -24,6 +23,9 @@
2423
)
2524
from codeflash.optimization.function_context import belongs_to_function_qualified
2625

26+
if TYPE_CHECKING:
27+
from jedi.api.classes import Name
28+
2729

2830
def get_code_optimization_context(
2931
function_to_optimize: FunctionToOptimize,
@@ -354,6 +356,8 @@ def extract_code_markdown_context_from_files(
354356
def get_function_to_optimize_as_function_source(
355357
function_to_optimize: FunctionToOptimize, project_root_path: Path
356358
) -> FunctionSource:
359+
import jedi
360+
357361
# Use jedi to find function to optimize
358362
script = jedi.Script(path=function_to_optimize.file_path, project=jedi.Project(path=project_root_path))
359363

@@ -389,6 +393,8 @@ def get_function_to_optimize_as_function_source(
389393
def get_function_sources_from_jedi(
390394
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
391395
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
396+
import jedi
397+
392398
file_path_to_function_source = defaultdict(set)
393399
function_source_list: list[FunctionSource] = []
394400
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():

0 commit comments

Comments
 (0)