|
4 | 4 | import shlex |
5 | 5 | import subprocess |
6 | 6 | from typing import TYPE_CHECKING, Optional |
| 7 | + |
7 | 8 | import isort |
8 | 9 |
|
9 | 10 | from codeflash.cli_cmds.console import console, logger |
10 | 11 |
|
11 | 12 | if TYPE_CHECKING: |
12 | 13 | from pathlib import Path |
13 | 14 |
|
14 | | -def get_diff_output(cmd: list[str]) -> Optional[str]: |
15 | | - try: |
16 | | - result = subprocess.run(cmd, capture_output=True, text=True, check=True) |
17 | | - return result.stdout.strip() or None |
18 | | - except (FileNotFoundError, subprocess.CalledProcessError) as e: |
19 | | - if isinstance(e, subprocess.CalledProcessError): |
20 | | - # ruff returns 1 when the file needs formatting, and 0 when it is already formatted |
21 | | - is_ruff = cmd[0] == "ruff" |
22 | | - if e.returncode == 0 and is_ruff: |
23 | | - return "" |
24 | | - if e.returncode == 1 and is_ruff: |
25 | | - return e.stdout.strip() or None |
26 | | - return None |
27 | | - |
28 | 15 |
|
def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]:
    """Return the diff black would apply to ``unformatted_content``.

    Args:
        filepath: Name used for both sides of the diff header.
        unformatted_content: Source text to run through black.

    Returns:
        The diff text produced by ``black.diff``, an empty string when black
        reports that nothing would change, or ``None`` when black cannot be
        imported. Other exceptions raised by black are not caught here.
    """
    try:
        import black
    except ImportError:
        return None

    try:
        formatted_content = black.format_file_contents(
            src_contents=unformatted_content, fast=True, mode=black.Mode()
        )
    except black.NothingChanged:
        # black signals "already formatted" by raising instead of returning the input,
        # so translate that into an empty diff rather than letting it escape.
        return ""
    return black.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath)
43 | 24 |
|
44 | 25 |
|
def get_diff_lines_count(diff_output: str) -> int:
    """Count the added and removed lines in a unified diff.

    Inside a hunk (the lines following an ``@@ -a,b +c,d @@`` header) every
    line starting with ``+`` or ``-`` is counted, so changed content that
    itself begins with ``--`` or ``++`` is not mistaken for a file header.
    Outside a hunk, ``+++``/``---`` lines are treated as file headers and
    skipped, which also keeps the result unchanged for input that has no
    hunk headers at all.

    Args:
        diff_output: Diff text; may be empty.

    Returns:
        The number of added plus removed lines.
    """

    def hunk_sizes(header: str) -> tuple[int, int]:
        # "@@ -start,len +start,len @@": a missing ",len" means a length of 1.
        fields = header.split()[1:3]
        if len(fields) != 2:
            # Malformed header: report an empty hunk so the header-skipping rule applies.
            return 0, 0
        old_len, new_len = (
            int(length) if length.isdigit() else 1
            for length in (field.partition(",")[2] for field in fields)
        )
        return old_len, new_len

    changed = 0
    old_left = 0
    new_left = 0
    for line in diff_output.split("\n"):
        if old_left > 0 or new_left > 0:
            # Still inside a hunk body: classify purely by the first character.
            if line.startswith("-"):
                changed += 1
                old_left -= 1
            elif line.startswith("+"):
                changed += 1
                new_left -= 1
            elif not line.startswith("\\"):
                # Context line; "\ No newline at end of file" markers consume nothing.
                old_left -= 1
                new_left -= 1
            continue

        if line.startswith("@@"):
            old_left, new_left = hunk_sizes(line)
        elif line.startswith(("+", "-")) and not line.startswith(("+++", "---")):
            changed += 1
    return changed
51 | 34 |
|
def is_safe_to_format(filepath: str, content: str, max_diff_lines: int = 100) -> bool:
    """Decide whether formatting ``content`` stays within the allowed diff size.

    Args:
        filepath: Path passed to the black diff helper and used in log messages.
        content: Current text of the file.
        max_diff_lines: Largest number of changed lines that is still accepted.

    Returns:
        ``True`` when the black diff has at most ``max_diff_lines`` changed
        lines; ``False`` when it has more, or when no diff could be obtained.
    """
    black_diff = get_diff_output_by_black(filepath, unformatted_content=content)
    if black_diff is None:
        logger.warning("Looks like black formatter not found, make sure it is installed.")
        return False

    changed = get_diff_lines_count(black_diff)
    if changed <= max_diff_lines:
        return True

    logger.debug(f"Skipping formatting {filepath}: {changed} lines would change (max: {max_diff_lines})")
    return False
71 | | - |
| 51 | + |
72 | 52 |
|
73 | 53 | def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa |
74 | 54 | # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution |
75 | 55 | formatter_name = formatter_cmds[0].lower() |
76 | 56 | if not path.exists(): |
77 | 57 | msg = f"File {path} does not exist. Cannot format the file." |
78 | 58 | raise FileNotFoundError(msg) |
79 | | - if formatter_name == "disabled" or not is_safe_to_format(str(path)): |
80 | | - return path.read_text(encoding="utf8") |
| 59 | + file_content = path.read_text(encoding="utf8") |
| 60 | + if formatter_name == "disabled" or not is_safe_to_format(filepath=str(path), content=file_content): |
| 61 | + return file_content |
81 | 62 |
|
82 | 63 | file_token = "$file" # noqa: S105 |
83 | 64 | for command in formatter_cmds: |
|
0 commit comments