Skip to content

Commit 22a99ce

Browse files
Merge branch 'main' into deferred-imports
2 parents 58fe58e + 5185423 commit 22a99ce

File tree

16 files changed

+306
-75
lines changed

16 files changed

+306
-75
lines changed

codeflash/api/cfapi.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,10 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
179179
if pr_number is None:
180180
return {}
181181

182-
not_found = 404
183-
internal_server_error = 500
184-
185182
owner, repo = get_repo_owner_and_name()
186183
information = {"pr_number": pr_number, "repo_owner": owner, "repo_name": repo}
187184
try:
188185
req = make_cfapi_request(endpoint="/verify-existing-optimizations", method="POST", payload=information)
189-
if req.status_code == not_found:
190-
logger.debug(req.json()["message"])
191-
return {}
192-
if req.status_code == internal_server_error:
193-
logger.error(req.json()["message"])
194-
return {}
195186
req.raise_for_status()
196187
content: dict[str, list[str]] = req.json()
197188
except Exception as e:

codeflash/cli_cmds/cmd_init.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from codeflash.cli_cmds.console import console, logger
2323
from codeflash.code_utils.compat import LF
2424
from codeflash.code_utils.config_parser import parse_config_file
25-
from codeflash.code_utils.env_utils import get_codeflash_api_key
25+
from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key
2626
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
2727
from codeflash.code_utils.github_utils import get_github_secrets_page_url
2828
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
@@ -201,7 +201,7 @@ def collect_setup_info() -> SetupInfo:
201201
path_type=inquirer.Path.DIRECTORY,
202202
)
203203
if custom_module_root_answer:
204-
module_root = Path(curdir) / Path(custom_module_root_answer["path"])
204+
module_root = Path(custom_module_root_answer["path"])
205205
else:
206206
apologize_and_exit()
207207
else:
@@ -514,7 +514,8 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
514514
from importlib.resources import files
515515

516516
benchmark_mode = False
517-
if "benchmarks_root" in config:
517+
benchmarks_root = config.get("benchmarks_root", "").strip()
518+
if benchmarks_root and benchmarks_root != "":
518519
benchmark_mode = inquirer_wrapper(
519520
inquirer.confirm,
520521
message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
@@ -537,7 +538,7 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
537538
existing_api_key = None
538539
click.prompt(
539540
f"Next, you'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repo.{LF}"
540-
f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)}{LF}"
541+
f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} {LF}"
541542
f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}"
542543
f"{'Here is your CODEFLASH_API_KEY: ' + existing_api_key + ' ' + LF}"
543544
if existing_api_key
@@ -720,11 +721,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
720721
)
721722
elif formatter == "don't use a formatter":
722723
formatter_cmds.append("disabled")
723-
if formatter in ["black", "ruff"]:
724-
try:
725-
subprocess.run([formatter], capture_output=True, check=False)
726-
except (FileNotFoundError, NotADirectoryError):
727-
click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed")
724+
check_formatter_installed(formatter_cmds, exit_on_failure=False)
728725
codeflash_section["formatter-cmds"] = formatter_cmds
729726
# Add the 'codeflash' section, ensuring 'tool' section exists
730727
tool_section = pyproject_data.get("tool", tomlkit.table())
@@ -750,7 +747,7 @@ def install_github_app() -> None:
750747

751748
else:
752749
click.prompt(
753-
f"Finally, you'll need install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
750+
f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
754751
f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
755752
f"Press Enter to open the page to let you install the app…{LF}",
756753
default="",
@@ -924,6 +921,14 @@ def test_sort():
924921

925922

926923
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
924+
try:
925+
check_formatter_installed(args.formatter_cmds)
926+
except Exception:
927+
logger.error(
928+
"Formatter not found. Review the formatter_cmds in your pyproject.toml file and make sure the formatter is installed."
929+
)
930+
return
931+
927932
command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
928933
if args.no_pr:
929934
command.append("--no-pr")

codeflash/code_utils/code_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,87 @@
22

33
import ast
44
import os
5+
import re
56
import shutil
67
import site
8+
from contextlib import contextmanager
79
from functools import lru_cache
810
from pathlib import Path
911
from tempfile import TemporaryDirectory
1012

13+
import tomlkit
14+
1115
from codeflash.cli_cmds.console import logger
16+
from codeflash.code_utils.config_parser import find_pyproject_toml
17+
18+
19+
@contextmanager
20+
def custom_addopts() -> None:
21+
pyproject_file = find_pyproject_toml()
22+
original_content = None
23+
non_blacklist_plugin_args = ""
24+
25+
try:
26+
# Read original file
27+
if pyproject_file.exists():
28+
with Path.open(pyproject_file, encoding="utf-8") as f:
29+
original_content = f.read()
30+
data = tomlkit.parse(original_content)
31+
# Backup original addopts
32+
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
33+
# nothing to do if no addopts present
34+
if original_addopts != "":
35+
original_addopts = [x.strip() for x in original_addopts]
36+
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
37+
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
38+
if non_blacklist_plugin_args != original_addopts:
39+
data["tool"]["pytest"]["ini_options"]["addopts"] = non_blacklist_plugin_args
40+
# Write modified file
41+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
42+
f.write(tomlkit.dumps(data))
43+
44+
yield
45+
46+
finally:
47+
# Restore original file
48+
if (
49+
original_content
50+
and pyproject_file.exists()
51+
and tuple(original_addopts) not in {(), tuple(non_blacklist_plugin_args)}
52+
):
53+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
54+
f.write(original_content)
55+
56+
57+
@contextmanager
58+
def add_addopts_to_pyproject() -> None:
59+
pyproject_file = find_pyproject_toml()
60+
original_content = None
61+
try:
62+
# Read original file
63+
if pyproject_file.exists():
64+
with Path.open(pyproject_file, encoding="utf-8") as f:
65+
original_content = f.read()
66+
data = tomlkit.parse(original_content)
67+
data["tool"]["pytest"] = {}
68+
data["tool"]["pytest"]["ini_options"] = {}
69+
data["tool"]["pytest"]["ini_options"]["addopts"] = [
70+
"-n=auto",
71+
"-n",
72+
"1",
73+
"-n 1",
74+
"-n 1",
75+
"-n auto",
76+
]
77+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
78+
f.write(tomlkit.dumps(data))
79+
80+
yield
81+
82+
finally:
83+
# Restore original file
84+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
85+
f.write(original_content)
1286

1387

1488
def encoded_tokens_len(s: str) -> int:

codeflash/code_utils/env_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,37 @@
11
from __future__ import annotations
22

33
import os
4+
import sys
5+
import tempfile
46
from functools import lru_cache
7+
from pathlib import Path
58
from typing import Optional
69

710
from codeflash.cli_cmds.console import logger
11+
from codeflash.code_utils.formatter import format_code
812
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
913

1014

15+
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
16+
return_code = True
17+
if formatter_cmds[0] == "disabled":
18+
return return_code
19+
tmp_code = """print("hello world")"""
20+
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f:
21+
f.write(tmp_code)
22+
f.flush()
23+
tmp_file = Path(f.name)
24+
try:
25+
format_code(formatter_cmds, tmp_file)
26+
except Exception:
27+
print(
28+
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again."
29+
)
30+
if exit_on_failure:
31+
sys.exit(1)
32+
return return_code
33+
34+
1135
@lru_cache(maxsize=1)
1236
def get_codeflash_api_key() -> str:
1337
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()

codeflash/code_utils/formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def format_code(formatter_cmds: list[str], path: Path) -> str:
2222
if formatter_name == "disabled":
2323
return path.read_text(encoding="utf8")
2424
file_token = "$file" # noqa: S105
25-
for command in set(formatter_cmds):
25+
for command in formatter_cmds:
2626
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
2727
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
2828
try:

codeflash/discovery/discover_unit_tests.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic.dataclasses import dataclass
1717

1818
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
19-
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
19+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path
2020
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
2121
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2222

@@ -149,19 +149,20 @@ def discover_tests_pytest(
149149
project_root = cfg.project_root_path
150150

151151
tmp_pickle_path = get_run_tmp_file("collected_tests.pkl")
152-
result = subprocess.run(
153-
[
154-
SAFE_SYS_EXECUTABLE,
155-
Path(__file__).parent / "pytest_new_process_discovery.py",
156-
str(project_root),
157-
str(tests_root),
158-
str(tmp_pickle_path),
159-
],
160-
cwd=project_root,
161-
check=False,
162-
capture_output=True,
163-
text=True,
164-
)
152+
with custom_addopts():
153+
result = subprocess.run(
154+
[
155+
SAFE_SYS_EXECUTABLE,
156+
Path(__file__).parent / "pytest_new_process_discovery.py",
157+
str(project_root),
158+
str(tests_root),
159+
str(tmp_pickle_path),
160+
],
161+
cwd=project_root,
162+
check=False,
163+
capture_output=True,
164+
text=True,
165+
)
165166
try:
166167
with tmp_pickle_path.open(mode="rb") as f:
167168
exitcode, tests, pytest_rootdir = pickle.load(f)

codeflash/optimization/function_optimizer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def determine_best_candidate(
392392
try:
393393
candidate_index = 0
394394
original_len = len(candidates)
395-
while True:
395+
while candidates:
396396
done = True if future_line_profile_results is None else future_line_profile_results.done()
397397
if done and (future_line_profile_results is not None):
398398
line_profile_results = future_line_profile_results.result()
@@ -402,13 +402,7 @@ def determine_best_candidate(
402402
f"Added results from line profiler to candidates, total candidates now: {original_len}"
403403
)
404404
future_line_profile_results = None
405-
try:
406-
candidate = candidates.popleft()
407-
except IndexError:
408-
if done:
409-
break
410-
time.sleep(0.1)
411-
continue
405+
candidate = candidates.popleft()
412406
candidate_index += 1
413407
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
414408
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
@@ -517,8 +511,17 @@ def determine_best_candidate(
517511
self.write_code_and_helpers(
518512
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
519513
)
520-
if done and not candidates:
521-
break
514+
if (not len(candidates)) and (
515+
not done
516+
): # all original candidates processed but lp results haven't been processed
517+
concurrent.futures.wait([future_line_profile_results])
518+
line_profile_results = future_line_profile_results.result()
519+
candidates.extend(line_profile_results)
520+
original_len += len(line_profile_results)
521+
logger.info(
522+
f"Added results from line profiler to candidates, total candidates now: {original_len}"
523+
)
524+
future_line_profile_results = None
522525
except KeyboardInterrupt as e:
523526
self.write_code_and_helpers(
524527
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path

codeflash/optimization/optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def run(self) -> None:
8484
console.rule()
8585
if not env_utils.ensure_codeflash_api_key():
8686
return
87+
if not env_utils.check_formatter_installed(self.args.formatter_cmds):
88+
return
8789
function_optimizer = None
8890
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
8991
num_optimizable_functions: int

codeflash/verification/test_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
from codeflash.cli_cmds.console import logger
9-
from codeflash.code_utils.code_utils import get_run_tmp_file
9+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
1010
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1111
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
1212
from codeflash.code_utils.coverage_utils import prepare_coverage_files
@@ -23,8 +23,9 @@ def execute_test_subprocess(
2323
cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600
2424
) -> subprocess.CompletedProcess:
2525
"""Execute a subprocess with the given command list, working directory, environment variables, and timeout."""
26-
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
27-
return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False)
26+
with custom_addopts():
27+
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
28+
return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False)
2829

2930

3031
def run_behavioral_tests(
@@ -97,6 +98,7 @@ def run_behavioral_tests(
9798
coverage_cmd.extend(shlex.split(pytest_cmd, posix=IS_POSIX)[1:])
9899

99100
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"]
101+
100102
results = execute_test_subprocess(
101103
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
102104
cwd=cwd,
@@ -252,7 +254,6 @@ def run_benchmarking_tests(
252254
pytest_test_env = test_env.copy()
253255
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
254256
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
255-
256257
results = execute_test_subprocess(
257258
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
258259
cwd=cwd,
@@ -278,7 +279,6 @@ def run_unittest_tests(
278279
log_level = ["-v"] if verbose else []
279280
files = [str(file) for file in test_file_paths]
280281
output_file = ["--output-file", str(result_file_path)]
281-
282282
results = execute_test_subprocess(
283283
unittest_cmd_list + log_level + files + output_file, cwd=cwd, env=test_env, timeout=600
284284
)

codeflash/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
2-
__version__ = "0.12.3"
3-
__version_tuple__ = (0, 12, 3)
2+
__version__ = "0.12.4"
3+
__version_tuple__ = (0, 12, 4)

0 commit comments

Comments
 (0)