diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index ec077f444..a75f35d50 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -4,6 +4,7 @@ import shlex import subprocess from typing import TYPE_CHECKING, Optional + import isort from codeflash.cli_cmds.console import console, logger @@ -11,6 +12,7 @@ if TYPE_CHECKING: from pathlib import Path + def get_diff_output(cmd: list[str]) -> Optional[str]: try: result = subprocess.run(cmd, capture_output=True, text=True, check=True) @@ -29,25 +31,32 @@ def get_diff_output(cmd: list[str]) -> Optional[str]: def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: try: import black # type: ignore - return get_diff_output(['black', '--diff', filepath]) + + return get_diff_output(["black", "--diff", filepath]) except ImportError: return None + def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]: try: import ruff # type: ignore - return get_diff_output(['ruff', 'format', '--diff', filepath]) + + return get_diff_output(["ruff", "format", "--diff", filepath]) except ImportError: print("can't import ruff") return None def get_diff_lines_count(diff_output: str) -> int: - lines = diff_output.split('\n') - def is_diff_line(line: str) -> bool: - return line.startswith(('+', '-')) and not line.startswith(('+++', '---')) - diff_lines = [line for line in lines if is_diff_line(line)] - return len(diff_lines) + count = 0 + for line in diff_output.split("\n"): + # Check only the minimal needed prefixes for diff lines + if line: + first = line[0] + if (first == "+" or first == "-") and not (line.startswith("+++") or line.startswith("---")): + count += 1 + return count + def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: diff_changes_stdout = None @@ -60,15 +69,15 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: if diff_changes_stdout is None: logger.warning("Both ruff, black formatters not found, skipping formatting diff check.") return False - + diff_lines_count = get_diff_lines_count(diff_changes_stdout) - + if diff_lines_count > max_diff_lines: logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") return False return True - + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution