11from __future__ import annotations
22
3+ import difflib
34from pathlib import Path
45
6+ from patchwork .logger import logger
57from patchwork .step import Step , StepStatus
68
79
8- def save_file_contents (file_path , content ):
9- """Utility function to save content to a file."""
10- with open (file_path , "w" ) as file :
10+ def save_file_contents (file_path : str | Path , content : str ) -> None :
11+ """Utility function to save content to a file.
12+
13+ Args:
14+ file_path: Path to the file to save content to (str or Path)
15+ content: Content to write to the file
16+ """
17+ path = Path (file_path )
18+ with path .open ("w" ) as file :
1119 file .write (content )
1220
1321
@@ -33,20 +41,26 @@ def handle_indent(src: list[str], target: list[str], start: int, end: int) -> li
3341
3442
3543def replace_code_in_file (
36- file_path : str ,
44+ file_path : str | Path ,
3745 start_line : int | None ,
3846 end_line : int | None ,
3947 new_code : str ,
4048) -> None :
49+ """Replace code in a file at the specified line range.
50+
51+ Args:
52+ file_path: Path to the file to modify (str or Path)
53+ start_line: Starting line number (1-based)
54+ end_line: Ending line number (1-based)
55+ new_code: New code to insert
56+ """
4157 path = Path (file_path )
4258 new_code_lines = new_code .splitlines (keepends = True )
4359 if len (new_code_lines ) > 0 and not new_code_lines [- 1 ].endswith ("\n " ):
4460 new_code_lines [- 1 ] += "\n "
4561
4662 if path .exists () and start_line is not None and end_line is not None :
47- """Replaces specified lines in a file with new code."""
4863 text = path .read_text ()
49-
5064 lines = text .splitlines (keepends = True )
5165
5266 # Insert the new code at the start line after converting it into a list of lines
@@ -55,7 +69,7 @@ def replace_code_in_file(
5569 lines = new_code_lines
5670
5771 # Save the modified contents back to the file
58- save_file_contents (file_path , "" .join (lines ))
72+ save_file_contents (path , "" .join (lines ))
5973
6074
6175class ModifyCode (Step ):
@@ -81,16 +95,53 @@ def run(self) -> dict:
8195 return dict (modified_code_files = [])
8296
8397 for code_snippet , extracted_response in sorted_list :
84- uri = code_snippet .get ("uri" )
98+ # Use Path for consistent path handling
99+ file_path = Path (code_snippet .get ("uri" , "" ))
85100 start_line = code_snippet .get ("startLine" )
86101 end_line = code_snippet .get ("endLine" )
87102 new_code = extracted_response .get ("patch" )
88103
89104 if new_code is None :
90105 continue
91106
92- replace_code_in_file (uri , start_line , end_line , new_code )
93- modified_code_file = dict (path = uri , start_line = start_line , end_line = end_line , ** extracted_response )
107+ # Get the original content for diffing
108+ diff = ""
109+ try :
110+ # Store original content in memory
111+ original_content = file_path .read_text () if file_path .exists () else ""
112+
113+ # Apply the changes
114+ replace_code_in_file (file_path , start_line , end_line , new_code )
115+
116+ # Read modified content
117+ current_content = file_path .read_text () if file_path .exists () else ""
118+
119+ # Generate unified diff
120+ fromfile = f"a/{ file_path } "
121+ tofile = f"b/{ file_path } "
122+ diff = "" .join (difflib .unified_diff (
123+ original_content .splitlines (keepends = True ),
124+ current_content .splitlines (keepends = True ),
125+ fromfile = fromfile ,
126+ tofile = tofile
127+ ))
128+
129+ if not diff and new_code : # If no diff but we have new code (new file)
130+ diff = f"+++ { file_path } \n { new_code } "
131+ except (OSError , IOError ) as e :
132+ logger .warning (f"Failed to generate diff for { file_path } : { str (e )} " )
133+ # Still proceed with the modification even if diff generation fails
134+ replace_code_in_file (file_path , start_line , end_line , new_code )
135+ diff = f"+++ { file_path } \n { new_code } " # Use new code as diff on error
136+
137+ # Create the modified code file dictionary
138+ modified_code_file = dict (
139+ path = str (file_path ),
140+ start_line = start_line ,
141+ end_line = end_line ,
142+ diff = diff ,
143+ ** extracted_response
144+ )
94145 modified_code_files .append (modified_code_file )
95146
96147 return dict (modified_code_files = modified_code_files )
0 commit comments