11from __future__ import annotations
22
3+ import difflib
34import os
5+ import re
46import shlex
7+ import shutil
58import subprocess
6- from typing import TYPE_CHECKING , Optional
9+ import tempfile
10+ from pathlib import Path
11+ from typing import Optional
712
813import isort
914
1015from codeflash .cli_cmds .console import console , logger
1116
12- if TYPE_CHECKING :
13- from pathlib import Path
1417
18+ def generate_unified_diff (original : str , modified : str , from_file : str , to_file : str ) -> str :
19+ line_pattern = re .compile (r"(.*?(?:\r\n|\n|\r|$))" )
1520
16- def get_diff_output_by_black (filepath : str , unformatted_content : str ) -> Optional [str ]:
17- try :
18- from black import Mode , format_file_contents , output , report
21+ def split_lines (text : str ) -> list [str ]:
22+ lines = [match [0 ] for match in line_pattern .finditer (text )]
23+ if lines and lines [- 1 ] == "" :
24+ lines .pop ()
25+ return lines
1926
20- formatted_content = format_file_contents (src_contents = unformatted_content , fast = True , mode = Mode ())
21- return output .diff (unformatted_content , formatted_content , a_name = filepath , b_name = filepath )
22- except (ImportError , report .NothingChanged ):
23- return None
27+ original_lines = split_lines (original )
28+ modified_lines = split_lines (modified )
2429
30+ diff_output = []
31+ for line in difflib .unified_diff (original_lines , modified_lines , fromfile = from_file , tofile = to_file , n = 5 ):
32+ if line .endswith ("\n " ):
33+ diff_output .append (line )
34+ else :
35+ diff_output .append (line + "\n " )
36+ diff_output .append ("\\ No newline at end of file\n " )
2537
26- def get_diff_lines_count (diff_output : str ) -> int :
27- lines = diff_output .split ("\n " )
28-
29- def is_diff_line (line : str ) -> bool :
30- return line .startswith (("+" , "-" )) and not line .startswith (("+++" , "---" ))
31-
32- diff_lines = [line for line in lines if is_diff_line (line )]
33- return len (diff_lines )
38+ return "" .join (diff_output )
3439
3540
36- def is_safe_to_format (filepath : str , content : str , max_diff_lines : int = 100 ) -> bool :
37- diff_changes_str = None
38-
39- diff_changes_str = get_diff_output_by_black (filepath , unformatted_content = content )
40-
41- if diff_changes_str is None :
42- logger .warning ("Looks like black formatter not found, make sure it is installed." )
43- return False
44-
45- diff_lines_count = get_diff_lines_count (diff_changes_str )
46- if diff_lines_count > max_diff_lines :
47- logger .debug (f"Skipping formatting { filepath } : { diff_lines_count } lines would change (max: { max_diff_lines } )" )
48- return False
41+ def apply_formatter_cmds (
42+ cmds : list [str ],
43+ path : Path ,
44+ test_dir_str : Optional [str ],
45+ print_status : bool , # noqa
46+ ) -> tuple [Path , str ]:
47+ # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
48+ formatter_name = cmds [0 ].lower ()
49+ should_make_copy = False
50+ file_path = path
4951
50- return True
52+ if test_dir_str :
53+ should_make_copy = True
54+ file_path = Path (test_dir_str ) / "temp.py"
5155
56+ if not cmds or formatter_name == "disabled" :
57+ return path , path .read_text (encoding = "utf8" )
5258
53- def format_code (formatter_cmds : list [str ], path : Path , print_status : bool = True ) -> str : # noqa
54- # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
55- formatter_name = formatter_cmds [0 ].lower ()
5659 if not path .exists ():
57- msg = f"File { path } does not exist. Cannot format the file ."
60+ msg = f"File { path } does not exist. Cannot apply formatter commands ."
5861 raise FileNotFoundError (msg )
59- file_content = path . read_text ( encoding = "utf8" )
60- if formatter_name == "disabled" or not is_safe_to_format ( filepath = str ( path ), content = file_content ) :
61- return file_content
62+
63+ if should_make_copy :
64+ shutil . copy2 ( path , file_path )
6265
6366 file_token = "$file" # noqa: S105
64- for command in formatter_cmds :
67+
68+ for command in cmds :
6569 formatter_cmd_list = shlex .split (command , posix = os .name != "nt" )
66- formatter_cmd_list = [path .as_posix () if chunk == file_token else chunk for chunk in formatter_cmd_list ]
70+ formatter_cmd_list = [file_path .as_posix () if chunk == file_token else chunk for chunk in formatter_cmd_list ]
6771 try :
6872 result = subprocess .run (formatter_cmd_list , capture_output = True , check = False )
6973 if result .returncode == 0 :
@@ -83,7 +87,45 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True
8387
8488 raise e from None
8589
86- return path .read_text (encoding = "utf8" )
90+ return file_path , file_path .read_text (encoding = "utf8" )
91+
92+
93+ def get_diff_lines_count (diff_output : str ) -> int :
94+ lines = diff_output .split ("\n " )
95+
96+ def is_diff_line (line : str ) -> bool :
97+ return line .startswith (("+" , "-" )) and not line .startswith (("+++" , "---" ))
98+
99+ diff_lines = [line for line in lines if is_diff_line (line )]
100+ return len (diff_lines )
101+
102+
103+ def format_code (formatter_cmds : list [str ], path : Path , optimized_function : str = "" , print_status : bool = True ) -> str : # noqa
104+ with tempfile .TemporaryDirectory () as test_dir_str :
105+ max_diff_lines = 100
106+
107+ original_code = path .read_text (encoding = "utf8" )
108+ # we dont' count the formatting diff for the optimized function as it should be well-formatted (if it's provided)
109+ original_code_without_opfunc = original_code .replace (optimized_function , "" )
110+
111+ original_temp = Path (test_dir_str ) / "original_temp.py"
112+ original_temp .write_text (original_code_without_opfunc , encoding = "utf8" )
113+
114+ formatted_temp , formatted_code = apply_formatter_cmds (
115+ formatter_cmds , original_temp , test_dir_str , print_status = False
116+ )
117+
118+ diff_output = generate_unified_diff (
119+ original_code_without_opfunc , formatted_code , from_file = str (original_temp ), to_file = str (formatted_temp )
120+ )
121+ diff_lines_count = get_diff_lines_count (diff_output )
122+ if diff_lines_count > max_diff_lines :
123+ logger .debug (f"Skipping formatting { path } : { diff_lines_count } lines would change (max: { max_diff_lines } )" )
124+ return original_code
125+
126+ _ , formatted_code = apply_formatter_cmds (formatter_cmds , path , test_dir_str = None , print_status = print_status )
127+ logger .debug (f"Formatted { path } with commands: { formatter_cmds } " )
128+ return formatted_code
87129
88130
89131def sort_imports (code : str ) -> str :
0 commit comments