Skip to content

Commit d2a8711

Browse files
formatting & using internal black dep
1 parent ce15022 commit d2a8711

File tree

4 files changed

+30
-50
lines changed

4 files changed

+30
-50
lines changed

codeflash/code_utils/formatter.py

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,80 +4,61 @@
44
import shlex
55
import subprocess
66
from typing import TYPE_CHECKING, Optional
7+
78
import isort
89

910
from codeflash.cli_cmds.console import console, logger
1011

1112
if TYPE_CHECKING:
1213
from pathlib import Path
1314

14-
def get_diff_output(cmd: list[str]) -> Optional[str]:
15-
try:
16-
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
17-
return result.stdout.strip() or None
18-
except (FileNotFoundError, subprocess.CalledProcessError) as e:
19-
if isinstance(e, subprocess.CalledProcessError):
20-
# ruff returns 1 when the file needs formatting, and 0 when it is already formatted
21-
is_ruff = cmd[0] == "ruff"
22-
if e.returncode == 0 and is_ruff:
23-
return ""
24-
if e.returncode == 1 and is_ruff:
25-
return e.stdout.strip() or None
26-
return None
27-
2815

29-
def get_diff_lines_output_by_black(filepath: str) -> Optional[str]:
16+
def get_diff_output_by_black(filepath: str, unformatted_content: str) -> Optional[str]:
3017
try:
31-
import black # type: ignore
32-
return get_diff_output(['black', '--diff', filepath])
33-
except ImportError:
34-
return None
18+
import black
3519

36-
def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]:
37-
try:
38-
import ruff # type: ignore
39-
return get_diff_output(['ruff', 'format', '--diff', filepath])
20+
formatted_content = black.format_file_contents(src_contents=unformatted_content, fast=True, mode=black.Mode())
21+
return black.diff(unformatted_content, formatted_content, a_name=filepath, b_name=filepath)
4022
except ImportError:
41-
print("can't import ruff")
4223
return None
4324

4425

4526
def get_diff_lines_count(diff_output: str) -> int:
46-
lines = diff_output.split('\n')
27+
lines = diff_output.split("\n")
28+
4729
def is_diff_line(line: str) -> bool:
48-
return line.startswith(('+', '-')) and not line.startswith(('+++', '---'))
30+
return line.startswith(("+", "-")) and not line.startswith(("+++", "---"))
31+
4932
diff_lines = [line for line in lines if is_diff_line(line)]
5033
return len(diff_lines)
5134

52-
def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool:
53-
diff_changes_stdout = None
5435

55-
diff_changes_stdout = get_diff_lines_output_by_black(filepath)
36+
def is_safe_to_format(filepath: str, content: str, max_diff_lines: int = 100) -> bool:
37+
diff_changes_str = None
38+
39+
diff_changes_str = get_diff_output_by_black(filepath, unformatted_content=content)
5640

57-
if diff_changes_stdout is None:
58-
logger.warning("black formatter not found, trying ruff instead...")
59-
diff_changes_stdout = get_diff_lines_output_by_ruff(filepath)
60-
if diff_changes_stdout is None:
61-
logger.warning("Both ruff, black formatters not found, skipping formatting diff check.")
62-
return False
63-
64-
diff_lines_count = get_diff_lines_count(diff_changes_stdout)
65-
41+
if diff_changes_str is None:
42+
logger.warning("Looks like black formatter not found, make sure it is installed.")
43+
return False
44+
45+
diff_lines_count = get_diff_lines_count(diff_changes_str)
6646
if diff_lines_count > max_diff_lines:
67-
logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})")
47+
logger.debug(f"Skipping formatting {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})")
6848
return False
6949

7050
return True
71-
51+
7252

7353
def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
7454
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
7555
formatter_name = formatter_cmds[0].lower()
7656
if not path.exists():
7757
msg = f"File {path} does not exist. Cannot format the file."
7858
raise FileNotFoundError(msg)
79-
if formatter_name == "disabled" or not is_safe_to_format(str(path)):
80-
return path.read_text(encoding="utf8")
59+
file_content = path.read_text(encoding="utf8")
60+
if formatter_name == "disabled" or not is_safe_to_format(filepath=str(path), content=file_content):
61+
return file_content
8162

8263
file_token = "$file" # noqa: S105
8364
for command in formatter_cmds:

poetry.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ crosshair-tool = ">=0.0.78"
9393
coverage = ">=7.6.4"
9494
line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13
9595
platformdirs = ">=4.3.7"
96+
black = "^25.1.0"
9697
[tool.poetry.group.dev]
9798
optional = true
9899

@@ -123,7 +124,6 @@ types-pexpect = "^4.9.0.20241208"
123124
types-unidiff = "^0.7.0.20240505"
124125
uv = ">=0.6.2"
125126
pre-commit = "^4.2.0"
126-
black = "^25.1.0"
127127

128128
[tool.poetry.build]
129129
script = "codeflash/update_license_version.py"

tests/test_formatter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,12 @@ def _run_formatting_test(source_filename: str, should_content_change: bool):
259259
args=args,
260260
)
261261

262-
optimizer.reformat_code_and_helpers(
262+
content, _ = optimizer.reformat_code_and_helpers(
263263
helper_functions=[],
264264
path=target_path,
265265
original_code=optimizer.function_to_optimize_source_code,
266266
)
267267

268-
content = target_path.read_text()
269268
if should_content_change:
270269
assert content != original, f"Expected content to change for {source_filename}"
271270
else:

0 commit comments

Comments
 (0)