diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 93d713402..6804ccc86 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -22,7 +22,7 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.compat import LF from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.env_utils import get_codeflash_api_key +from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name from codeflash.code_utils.github_utils import get_github_secrets_page_url from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc @@ -720,11 +720,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: ) elif formatter == "don't use a formatter": formatter_cmds.append("disabled") - if formatter in ["black", "ruff"]: - try: - subprocess.run([formatter], capture_output=True, check=False) - except (FileNotFoundError, NotADirectoryError): - click.echo(f"⚠️ Formatter not found: {formatter}, please ensure it is installed") + check_formatter_installed(formatter_cmds) codeflash_section["formatter-cmds"] = formatter_cmds # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) @@ -924,6 +920,14 @@ def test_sort(): def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None: + try: + check_formatter_installed(args.formatter_cmds) + except Exception: + logger.error( + "Formatter not found. Review the formatter_cmds in your pyproject.toml file and make sure the formatter is installed." + ) + return + command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"] if args.no_pr: command.append("--no-pr") diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 41ef89351..ee18cfa00 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -1,13 +1,52 @@ from __future__ import annotations import os +import shlex +import subprocess +import tempfile from functools import lru_cache +from pathlib import Path from typing import Optional from codeflash.cli_cmds.console import logger from codeflash.code_utils.shell_utils import read_api_key_from_shell_config +class FormatterNotFoundError(Exception): + """Exception raised when a formatter is not found.""" + + def __init__(self, formatter_cmd: str) -> None: + super().__init__(f"Formatter command not found: {formatter_cmd}") + + +def check_formatter_installed(formatter_cmds: list[str]) -> bool: + return_code = True + if formatter_cmds[0] == "disabled": + return return_code + tmp_code = """print("hello world")""" + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f: + f.write(tmp_code) + f.flush() + tmp_file = Path(f.name) + file_token = "$file" # noqa: S105 + for command in set(formatter_cmds): + formatter_cmd_list = shlex.split(command, posix=os.name != "nt") + formatter_cmd_list = [tmp_file.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list] + try: + result = subprocess.run(formatter_cmd_list, capture_output=True, check=False) + except (FileNotFoundError, NotADirectoryError): + return_code = False + break + if result.returncode: + return_code = False + break + tmp_file.unlink(missing_ok=True) + if not return_code: + msg = f"Error running formatter command: {command}" + raise FormatterNotFoundError(msg) + return return_code + + @lru_cache(maxsize=1) def get_codeflash_api_key() -> str: api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config() diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index ae8ece469..90c088ec6 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -82,6 +82,8 @@ def run(self) -> None: console.rule() if not env_utils.ensure_codeflash_api_key(): return + if not env_utils.check_formatter_installed(self.args.formatter_cmds): + return function_optimizer = None file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] num_optimizable_functions: int