Skip to content

Commit 2b522d4

Browse files
committed
Some logic changes and more succint code
1 parent 5452514 commit 2b522d4

File tree

1 file changed

+38
-44
lines changed

1 file changed

+38
-44
lines changed

patchwork/steps/FixIssue/FixIssue.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,23 @@ def __init__(self, inputs):
107107
- Other LLM-related parameters
108108
"""
109109
super().__init__(inputs)
110-
base_path = inputs.get("base_path")
111-
# Handle base_path carefully to avoid type issues
112-
if base_path is not None:
113-
self.base_path = str(Path(str(base_path)).resolve())
114-
else:
115-
self.base_path = str(Path.cwd())
116-
110+
cwd = str(Path.cwd())
111+
original_base_path = inputs.get("base_path")
112+
113+
if original_base_path is not None:
114+
original_base_path = str(Path(str(original_base_path)).resolve())
115+
117116
# Check if we're in a git repository
118117
try:
119-
self.repo = Repo(self.base_path, search_parent_directories=True)
120-
self.is_git_repo = True
118+
self.repo = Repo(original_base_path or cwd, search_parent_directories=True)
121119
except (InvalidGitRepositoryError, Exception):
122120
self.repo = None
123-
self.is_git_repo = False
121+
122+
repo_working_dir = None
123+
if self.repo is not None:
124+
repo_working_dir = self.repo.working_dir
125+
126+
self.base_path = original_base_path or repo_working_dir or cwd
124127

125128
llm_client = AioLlmClient.create_aio_client(inputs)
126129
if llm_client is None:
@@ -152,38 +155,29 @@ def run(self):
152155
dict: Dictionary containing list of modified files with their diffs
153156
"""
154157
self.multiturn_llm_call.execute(limit=100)
158+
159+
modified_files = []
160+
cwd = Path.cwd()
155161
for tool in self.multiturn_llm_call.tool_set.values():
156-
if isinstance(tool, CodeEditTool):
157-
cwd = Path.cwd()
158-
modified_files = [file_path.relative_to(cwd) for file_path in tool.tool_records["modified_files"]]
159-
# Generate diffs for modified files
160-
modified_files_with_diffs = []
161-
162-
for file in modified_files:
163-
file_path = Path(file)
164-
modified_file = {
165-
"path": str(file),
166-
"diff": "" # Default to empty string as requested
167-
}
168-
169-
# Only try to generate git diff if we're in a git repository
170-
if self.is_git_repo and self.repo is not None:
171-
try:
172-
# Check if file exists and is tracked by git
173-
if file_path.exists():
174-
try:
175-
# Try to get the diff using git
176-
diff = self.repo.git.diff('HEAD', str(file))
177-
if diff: # Only update if we got a diff
178-
modified_file["diff"] = diff
179-
except Exception as e:
180-
# Git-specific errors (untracked files, etc) - keep empty diff
181-
logger.warning(f"Could not get git diff for {file}: {str(e)}")
182-
except Exception as e:
183-
# General file processing errors
184-
logger.warning(f"Failed to process file {file}: {str(e)}")
185-
186-
modified_files_with_diffs.append(modified_file)
187-
188-
return dict(modified_files=modified_files_with_diffs)
189-
return dict()
162+
if not isinstance(tool, CodeEditTool):
163+
continue
164+
tool_modified_files = [
165+
dict(path=str(file_path.relative_to(cwd)), diff="")
166+
for file_path in tool.tool_records["modified_files"]
167+
]
168+
modified_files.extend(tool_modified_files)
169+
170+
# Generate diffs for modified files
171+
# Only try to generate git diff if we're in a git repository
172+
if self.repo is not None:
173+
for modified_file in modified_files:
174+
file = modified_file["path"]
175+
try:
176+
# Try to get the diff using git
177+
diff = self.repo.git.diff('HEAD', file)
178+
modified_file["diff"] = diff or ""
179+
except Exception as e:
180+
# Git-specific errors (untracked files, etc) - keep empty diff
181+
logger.warning(f"Could not get git diff for {file}: {str(e)}")
182+
183+
return dict(modified_files=modified_files)

0 commit comments

Comments
 (0)