diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index e1d269aa7..debd87106 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -4,6 +4,7 @@ import shlex import subprocess from typing import TYPE_CHECKING, Optional + import isort from codeflash.cli_cmds.console import console, logger @@ -11,12 +12,14 @@ if TYPE_CHECKING: from pathlib import Path + def get_nth_line(text: str, n: int) -> str | None: for i, line in enumerate(text.splitlines(), start=1): if i == n: return line return None + def get_diff_output(cmd: list[str]) -> Optional[str]: try: result = subprocess.run(cmd, capture_output=True, text=True, check=True) @@ -35,25 +38,33 @@ def get_diff_output(cmd: list[str]) -> Optional[str]: def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: try: import black # type: ignore - return get_diff_output(['black', '--diff', filepath]) + + return get_diff_output(["black", "--diff", filepath]) except ImportError: return None + def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]: try: import ruff # type: ignore - return get_diff_output(['ruff', 'format', '--diff', filepath]) + + return get_diff_output(["ruff", "format", "--diff", filepath]) except ImportError: print("can't import ruff") return None def get_diff_lines_count(diff_output: str) -> int: - lines = diff_output.split('\n') - def is_diff_line(line: str) -> bool: - return line.startswith(('+', '-')) and not line.startswith(('+++', '---')) - diff_lines = [line for line in lines if is_diff_line(line)] - return len(diff_lines) + # Count the number of diff lines in the given diff_output string + count = 0 + for line in diff_output.split("\n"): + if line: + c = line[0] + # Check first character and avoid lines starting with '+++', '---' + if (c == "+" or c == "-") and not (line.startswith("+++") or line.startswith("---")): + count += 1 + return count + def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: diff_changes_stdout = None @@ -66,15 +77,15 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: if diff_changes_stdout is None: logger.warning("Both ruff, black formatters not found, skipping formatting diff check.") return False - + diff_lines_count = get_diff_lines_count(diff_changes_stdout) - + if diff_lines_count > max_diff_lines: logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") return False return True - + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution