Skip to content

Commit 17f9291

Browse files
⚡️ Speed up function get_diff_lines_count by 19% in PR #274 (skip-formatting-for-large-diffs)
Here is a **much faster** rewrite. The biggest bottleneck was constructing the entire `diff_lines` list just to count its length. Instead, loop directly through the lines and count matching lines, avoiding extra memory and function call overhead. This also removes the small overhead of the nested function. ### Optimizations made. - **No internal list allocation:** Now iterating and counting in one pass with no extra list. - **No inner function call:** Faster, via direct string checks. - **Short-circuit on empty:** Avoids string indexing on empty lines. - **Direct char compare for '+', '-':** Faster than using tuple membership or `startswith` with a tuple. This reduces both runtime **and** memory usage by avoiding unnecessary data structures!
1 parent 822d6cc commit 17f9291

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

codeflash/code_utils/formatter.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,22 @@
44
import shlex
55
import subprocess
66
from typing import TYPE_CHECKING, Optional
7+
78
import isort
89

910
from codeflash.cli_cmds.console import console, logger
1011

1112
if TYPE_CHECKING:
1213
from pathlib import Path
1314

15+
1416
def get_nth_line(text: str, n: int) -> str | None:
1517
for i, line in enumerate(text.splitlines(), start=1):
1618
if i == n:
1719
return line
1820
return None
1921

22+
2023
def get_diff_output(cmd: list[str]) -> Optional[str]:
2124
try:
2225
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
@@ -35,25 +38,33 @@ def get_diff_output(cmd: list[str]) -> Optional[str]:
3538
def get_diff_lines_output_by_black(filepath: str) -> Optional[str]:
3639
try:
3740
import black # type: ignore
38-
return get_diff_output(['black', '--diff', filepath])
41+
42+
return get_diff_output(["black", "--diff", filepath])
3943
except ImportError:
4044
return None
4145

46+
4247
def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]:
4348
try:
4449
import ruff # type: ignore
45-
return get_diff_output(['ruff', 'format', '--diff', filepath])
50+
51+
return get_diff_output(["ruff", "format", "--diff", filepath])
4652
except ImportError:
4753
print("can't import ruff")
4854
return None
4955

5056

5157
def get_diff_lines_count(diff_output: str) -> int:
52-
lines = diff_output.split('\n')
53-
def is_diff_line(line: str) -> bool:
54-
return line.startswith(('+', '-')) and not line.startswith(('+++', '---'))
55-
diff_lines = [line for line in lines if is_diff_line(line)]
56-
return len(diff_lines)
58+
# Count the number of diff lines in the given diff_output string
59+
count = 0
60+
for line in diff_output.split("\n"):
61+
if line:
62+
c = line[0]
63+
# Check first character and avoid lines starting with '+++', '---'
64+
if (c == "+" or c == "-") and not (line.startswith("+++") or line.startswith("---")):
65+
count += 1
66+
return count
67+
5768

5869
def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool:
5970
diff_changes_stdout = None
@@ -66,15 +77,15 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool:
6677
if diff_changes_stdout is None:
6778
logger.warning("Both ruff, black formatters not found, skipping formatting diff check.")
6879
return False
69-
80+
7081
diff_lines_count = get_diff_lines_count(diff_changes_stdout)
71-
82+
7283
if diff_lines_count > max_diff_lines:
7384
logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})")
7485
return False
7586

7687
return True
77-
88+
7889

7990
def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
8091
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution

0 commit comments

Comments
 (0)