Skip to content

Commit 291b3f9

Browse files
Merge branch 'main' into tracer-optimization
2 parents ff3222b + f628f31 commit 291b3f9

File tree

13 files changed

+318
-63
lines changed

13 files changed

+318
-63
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from codeflash.cli_cmds.console import console, logger
2323
from codeflash.code_utils.compat import LF
2424
from codeflash.code_utils.config_parser import parse_config_file
25-
from codeflash.code_utils.env_utils import get_codeflash_api_key
25+
from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key
2626
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
2727
from codeflash.code_utils.github_utils import get_github_secrets_page_url
2828
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
@@ -201,7 +201,7 @@ def collect_setup_info() -> SetupInfo:
201201
path_type=inquirer.Path.DIRECTORY,
202202
)
203203
if custom_module_root_answer:
204-
module_root = Path(curdir) / Path(custom_module_root_answer["path"])
204+
module_root = Path(custom_module_root_answer["path"])
205205
else:
206206
apologize_and_exit()
207207
else:
@@ -514,7 +514,8 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
514514
from importlib.resources import files
515515

516516
benchmark_mode = False
517-
if "benchmarks_root" in config:
517+
benchmarks_root = config.get("benchmarks_root", "").strip()
518+
if benchmarks_root and benchmarks_root != "":
518519
benchmark_mode = inquirer_wrapper(
519520
inquirer.confirm,
520521
message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
@@ -537,7 +538,7 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
537538
existing_api_key = None
538539
click.prompt(
539540
f"Next, you'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repo.{LF}"
540-
f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)}{LF}"
541+
f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} {LF}"
541542
f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}"
542543
f"{'Here is your CODEFLASH_API_KEY: ' + existing_api_key + ' ' + LF}"
543544
if existing_api_key
@@ -720,11 +721,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
720721
)
721722
elif formatter == "don't use a formatter":
722723
formatter_cmds.append("disabled")
723-
if formatter in ["black", "ruff"]:
724-
try:
725-
subprocess.run([formatter], capture_output=True, check=False)
726-
except (FileNotFoundError, NotADirectoryError):
727-
click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed")
724+
check_formatter_installed(formatter_cmds)
728725
codeflash_section["formatter-cmds"] = formatter_cmds
729726
# Add the 'codeflash' section, ensuring 'tool' section exists
730727
tool_section = pyproject_data.get("tool", tomlkit.table())
@@ -750,7 +747,7 @@ def install_github_app() -> None:
750747

751748
else:
752749
click.prompt(
753-
f"Finally, you'll need install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
750+
f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
754751
f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
755752
f"Press Enter to open the page to let you install the app…{LF}",
756753
default="",
@@ -924,6 +921,14 @@ def test_sort():
924921

925922

926923
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
924+
try:
925+
check_formatter_installed(args.formatter_cmds)
926+
except Exception:
927+
logger.error(
928+
"Formatter not found. Review the formatter_cmds in your pyproject.toml file and make sure the formatter is installed."
929+
)
930+
return
931+
927932
command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
928933
if args.no_pr:
929934
command.append("--no-pr")

codeflash/code_utils/code_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,87 @@
22

33
import ast
44
import os
5+
import re
56
import shutil
67
import site
8+
from contextlib import contextmanager
79
from functools import lru_cache
810
from pathlib import Path
911
from tempfile import TemporaryDirectory
1012

13+
import tomlkit
14+
1115
from codeflash.cli_cmds.console import logger
16+
from codeflash.code_utils.config_parser import find_pyproject_toml
17+
18+
19+
@contextmanager
20+
def custom_addopts() -> None:
21+
pyproject_file = find_pyproject_toml()
22+
original_content = None
23+
non_blacklist_plugin_args = ""
24+
25+
try:
26+
# Read original file
27+
if pyproject_file.exists():
28+
with Path.open(pyproject_file, encoding="utf-8") as f:
29+
original_content = f.read()
30+
data = tomlkit.parse(original_content)
31+
# Backup original addopts
32+
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
33+
# nothing to do if no addopts present
34+
if original_addopts != "":
35+
original_addopts = [x.strip() for x in original_addopts]
36+
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
37+
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
38+
if non_blacklist_plugin_args != original_addopts:
39+
data["tool"]["pytest"]["ini_options"]["addopts"] = non_blacklist_plugin_args
40+
# Write modified file
41+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
42+
f.write(tomlkit.dumps(data))
43+
44+
yield
45+
46+
finally:
47+
# Restore original file
48+
if (
49+
original_content
50+
and pyproject_file.exists()
51+
and tuple(original_addopts) not in {(), tuple(non_blacklist_plugin_args)}
52+
):
53+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
54+
f.write(original_content)
55+
56+
57+
@contextmanager
58+
def add_addopts_to_pyproject() -> None:
59+
pyproject_file = find_pyproject_toml()
60+
original_content = None
61+
try:
62+
# Read original file
63+
if pyproject_file.exists():
64+
with Path.open(pyproject_file, encoding="utf-8") as f:
65+
original_content = f.read()
66+
data = tomlkit.parse(original_content)
67+
data["tool"]["pytest"] = {}
68+
data["tool"]["pytest"]["ini_options"] = {}
69+
data["tool"]["pytest"]["ini_options"]["addopts"] = [
70+
"-n=auto",
71+
"-n",
72+
"1",
73+
"-n 1",
74+
"-n 1",
75+
"-n auto",
76+
]
77+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
78+
f.write(tomlkit.dumps(data))
79+
80+
yield
81+
82+
finally:
83+
# Restore original file
84+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
85+
f.write(original_content)
1286

1387

1488
def encoded_tokens_len(s: str) -> int:

codeflash/code_utils/env_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,52 @@
11
from __future__ import annotations
22

33
import os
4+
import shlex
5+
import subprocess
6+
import tempfile
47
from functools import lru_cache
8+
from pathlib import Path
59
from typing import Optional
610

711
from codeflash.cli_cmds.console import logger
812
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
913

1014

15+
class FormatterNotFoundError(Exception):
16+
"""Exception raised when a formatter is not found."""
17+
18+
def __init__(self, formatter_cmd: str) -> None:
19+
super().__init__(f"Formatter command not found: {formatter_cmd}")
20+
21+
22+
def check_formatter_installed(formatter_cmds: list[str]) -> bool:
23+
return_code = True
24+
if formatter_cmds[0] == "disabled":
25+
return return_code
26+
tmp_code = """print("hello world")"""
27+
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f:
28+
f.write(tmp_code)
29+
f.flush()
30+
tmp_file = Path(f.name)
31+
file_token = "$file" # noqa: S105
32+
for command in set(formatter_cmds):
33+
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
34+
formatter_cmd_list = [tmp_file.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
35+
try:
36+
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
37+
except (FileNotFoundError, NotADirectoryError):
38+
return_code = False
39+
break
40+
if result.returncode:
41+
return_code = False
42+
break
43+
tmp_file.unlink(missing_ok=True)
44+
if not return_code:
45+
msg = f"Error running formatter command: {command}"
46+
raise FormatterNotFoundError(msg)
47+
return return_code
48+
49+
1150
@lru_cache(maxsize=1)
1251
def get_codeflash_api_key() -> str:
1352
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()

codeflash/discovery/discover_unit_tests.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pydantic.dataclasses import dataclass
1818

1919
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
20-
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
20+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path
2121
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
2222
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2323

@@ -150,19 +150,20 @@ def discover_tests_pytest(
150150
project_root = cfg.project_root_path
151151

152152
tmp_pickle_path = get_run_tmp_file("collected_tests.pkl")
153-
result = subprocess.run(
154-
[
155-
SAFE_SYS_EXECUTABLE,
156-
Path(__file__).parent / "pytest_new_process_discovery.py",
157-
str(project_root),
158-
str(tests_root),
159-
str(tmp_pickle_path),
160-
],
161-
cwd=project_root,
162-
check=False,
163-
capture_output=True,
164-
text=True,
165-
)
153+
with custom_addopts():
154+
result = subprocess.run(
155+
[
156+
SAFE_SYS_EXECUTABLE,
157+
Path(__file__).parent / "pytest_new_process_discovery.py",
158+
str(project_root),
159+
str(tests_root),
160+
str(tmp_pickle_path),
161+
],
162+
cwd=project_root,
163+
check=False,
164+
capture_output=True,
165+
text=True,
166+
)
166167
try:
167168
with tmp_pickle_path.open(mode="rb") as f:
168169
exitcode, tests, pytest_rootdir = pickle.load(f)

codeflash/optimization/function_optimizer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def determine_best_candidate(
392392
try:
393393
candidate_index = 0
394394
original_len = len(candidates)
395-
while True:
395+
while candidates:
396396
done = True if future_line_profile_results is None else future_line_profile_results.done()
397397
if done and (future_line_profile_results is not None):
398398
line_profile_results = future_line_profile_results.result()
@@ -402,13 +402,7 @@ def determine_best_candidate(
402402
f"Added results from line profiler to candidates, total candidates now: {original_len}"
403403
)
404404
future_line_profile_results = None
405-
try:
406-
candidate = candidates.popleft()
407-
except IndexError:
408-
if done:
409-
break
410-
time.sleep(0.1)
411-
continue
405+
candidate = candidates.popleft()
412406
candidate_index += 1
413407
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
414408
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
@@ -517,8 +511,17 @@ def determine_best_candidate(
517511
self.write_code_and_helpers(
518512
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
519513
)
520-
if done and not candidates:
521-
break
514+
if (not len(candidates)) and (
515+
not done
516+
): # all original candidates processed but lp results haven't been processed
517+
concurrent.futures.wait([future_line_profile_results])
518+
line_profile_results = future_line_profile_results.result()
519+
candidates.extend(line_profile_results)
520+
original_len += len(line_profile_results)
521+
logger.info(
522+
f"Added results from line profiler to candidates, total candidates now: {original_len}"
523+
)
524+
future_line_profile_results = None
522525
except KeyboardInterrupt as e:
523526
self.write_code_and_helpers(
524527
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path

codeflash/optimization/optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def run(self) -> None:
8282
console.rule()
8383
if not env_utils.ensure_codeflash_api_key():
8484
return
85+
if not env_utils.check_formatter_installed(self.args.formatter_cmds):
86+
return
8587
function_optimizer = None
8688
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
8789
num_optimizable_functions: int

codeflash/verification/test_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
from codeflash.cli_cmds.console import logger
9-
from codeflash.code_utils.code_utils import get_run_tmp_file
9+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
1010
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1111
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
1212
from codeflash.code_utils.coverage_utils import prepare_coverage_files
@@ -23,8 +23,9 @@ def execute_test_subprocess(
2323
cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600
2424
) -> subprocess.CompletedProcess:
2525
"""Execute a subprocess with the given command list, working directory, environment variables, and timeout."""
26-
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
27-
return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False)
26+
with custom_addopts():
27+
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
28+
return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False)
2829

2930

3031
def run_behavioral_tests(
@@ -97,6 +98,7 @@ def run_behavioral_tests(
9798
coverage_cmd.extend(shlex.split(pytest_cmd, posix=IS_POSIX)[1:])
9899

99100
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"]
101+
100102
results = execute_test_subprocess(
101103
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
102104
cwd=cwd,
@@ -252,7 +254,6 @@ def run_benchmarking_tests(
252254
pytest_test_env = test_env.copy()
253255
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
254256
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
255-
256257
results = execute_test_subprocess(
257258
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
258259
cwd=cwd,
@@ -278,7 +279,6 @@ def run_unittest_tests(
278279
log_level = ["-v"] if verbose else []
279280
files = [str(file) for file in test_file_paths]
280281
output_file = ["--output-file", str(result_file_path)]
281-
282282
results = execute_test_subprocess(
283283
unittest_cmd_list + log_level + files + output_file, cwd=cwd, env=test_env, timeout=600
284284
)

0 commit comments

Comments
 (0)