11from __future__ import annotations
22
3+ import difflib
34import os
5+ import re
46import shlex
7+ import shutil
58import subprocess
6- from typing import TYPE_CHECKING
9+ import tempfile
10+ from pathlib import Path
11+ from typing import Optional , Union
712
813import isort
914
1015from codeflash .cli_cmds .console import console , logger
1116
12- if TYPE_CHECKING :
13- from pathlib import Path
1417
18+ def generate_unified_diff (original : str , modified : str , from_file : str , to_file : str ) -> str :
19+ line_pattern = re .compile (r"(.*?(?:\r\n|\n|\r|$))" )
1520
16- def format_code (formatter_cmds : list [str ], path : Path , print_status : bool = True ) -> str : # noqa
21+ def split_lines (text : str ) -> list [str ]:
22+ lines = [match [0 ] for match in line_pattern .finditer (text )]
23+ if lines and lines [- 1 ] == "" :
24+ lines .pop ()
25+ return lines
26+
27+ original_lines = split_lines (original )
28+ modified_lines = split_lines (modified )
29+
30+ diff_output = []
31+ for line in difflib .unified_diff (original_lines , modified_lines , fromfile = from_file , tofile = to_file , n = 5 ):
32+ if line .endswith ("\n " ):
33+ diff_output .append (line )
34+ else :
35+ diff_output .append (line + "\n " )
36+ diff_output .append ("\\ No newline at end of file\n " )
37+
38+ return "" .join (diff_output )
39+
40+
41+ def apply_formatter_cmds (
42+ cmds : list [str ],
43+ path : Path ,
44+ test_dir_str : Optional [str ],
45+ print_status : bool , # noqa
46+ ) -> tuple [Path , str ]:
1747 # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
18- formatter_name = formatter_cmds [0 ].lower ()
48+ formatter_name = cmds [0 ].lower ()
49+ should_make_copy = False
50+ file_path = path
51+
52+ if test_dir_str :
53+ should_make_copy = True
54+ file_path = Path (test_dir_str ) / "temp.py"
55+
56+ if not cmds or formatter_name == "disabled" :
57+ return path , path .read_text (encoding = "utf8" )
58+
1959 if not path .exists ():
20- msg = f"File { path } does not exist. Cannot format the file ."
60+ msg = f"File { path } does not exist. Cannot apply formatter commands ."
2161 raise FileNotFoundError (msg )
22- if formatter_name == "disabled" :
23- return path .read_text (encoding = "utf8" )
62+
63+ if should_make_copy :
64+ shutil .copy2 (path , file_path )
65+
2466 file_token = "$file" # noqa: S105
25- for command in formatter_cmds :
67+
68+ for command in cmds :
2669 formatter_cmd_list = shlex .split (command , posix = os .name != "nt" )
27- formatter_cmd_list = [path .as_posix () if chunk == file_token else chunk for chunk in formatter_cmd_list ]
70+ formatter_cmd_list = [file_path .as_posix () if chunk == file_token else chunk for chunk in formatter_cmd_list ]
2871 try :
2972 result = subprocess .run (formatter_cmd_list , capture_output = True , check = False )
3073 if result .returncode == 0 :
3174 if print_status :
32- console .rule (f"Formatted Successfully with: { formatter_name .replace ('$file' , path .name )} " )
75+ console .rule (f"Formatted Successfully with: { command .replace ('$file' , path .name )} " )
3376 else :
3477 logger .error (f"Failed to format code with { ' ' .join (formatter_cmd_list )} " )
3578 except FileNotFoundError as e :
@@ -44,7 +87,60 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
4487
4588 raise e from None
4689
47- return path .read_text (encoding = "utf8" )
90+ return file_path , file_path .read_text (encoding = "utf8" )
91+
92+
93+ def get_diff_lines_count (diff_output : str ) -> int :
94+ lines = diff_output .split ("\n " )
95+
96+ def is_diff_line (line : str ) -> bool :
97+ return line .startswith (("+" , "-" )) and not line .startswith (("+++" , "---" ))
98+
99+ diff_lines = [line for line in lines if is_diff_line (line )]
100+ return len (diff_lines )
101+
102+
103+ def format_code (
104+ formatter_cmds : list [str ],
105+ path : Union [str , Path ],
106+ optimized_function : str = "" ,
107+ check_diff : bool = False , # noqa
108+ print_status : bool = True , # noqa
109+ ) -> str :
110+ with tempfile .TemporaryDirectory () as test_dir_str :
111+ if isinstance (path , str ):
112+ path = Path (path )
113+
114+ original_code = path .read_text (encoding = "utf8" )
115+ original_code_lines = len (original_code .split ("\n " ))
116+
117+ if check_diff and original_code_lines > 50 :
118+ # we dont' count the formatting diff for the optimized function as it should be well-formatted
119+ original_code_without_opfunc = original_code .replace (optimized_function , "" )
120+
121+ original_temp = Path (test_dir_str ) / "original_temp.py"
122+ original_temp .write_text (original_code_without_opfunc , encoding = "utf8" )
123+
124+ formatted_temp , formatted_code = apply_formatter_cmds (
125+ formatter_cmds , original_temp , test_dir_str , print_status = False
126+ )
127+
128+ diff_output = generate_unified_diff (
129+ original_code_without_opfunc , formatted_code , from_file = str (original_temp ), to_file = str (formatted_temp )
130+ )
131+ diff_lines_count = get_diff_lines_count (diff_output )
132+
133+ max_diff_lines = min (int (original_code_lines * 0.3 ), 50 )
134+
135+ if diff_lines_count > max_diff_lines and max_diff_lines != - 1 :
136+ logger .debug (
137+ f"Skipping formatting { path } : { diff_lines_count } lines would change (max: { max_diff_lines } )"
138+ )
139+ return original_code
140+ # TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
141+ _ , formatted_code = apply_formatter_cmds (formatter_cmds , path , test_dir_str = None , print_status = print_status )
142+ logger .debug (f"Formatted { path } with commands: { formatter_cmds } " )
143+ return formatted_code
48144
49145
50146def sort_imports (code : str ) -> str :
0 commit comments