Skip to content

Commit fd27cb7

Browse files
author
Xu
committed
fix no file bugs
1 parent 1064b68 commit fd27cb7

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

rdagent/scenarios/data_science/dev/runner/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ def implement_one_task(
113113
queried_former_failed_knowledge=queried_former_failed_knowledge[0],
114114
)
115115
code = session.build_chat_completion(user_prompt=user_prompt)
116-
code_batch_edit = extract_output_fn(code)
116+
if self.settings.diff_mode:
117+
code_batch_edit = extract_output_fn(code, prefix=workspace.workspace_path)
118+
else:
119+
code_batch_edit = extract_output_fn(code)
117120
code_batch_edit = {k: v for k, v in code_batch_edit.items() if k in workspace.file_dict.keys()}
118121

119122
# Change Summary

rdagent/utils/agent/apply_patch.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -417,17 +417,26 @@ def text_to_patch(text: str, orig: dict[str, str]) -> tuple[Patch, int]:
417417
return parser.patch, parser.fuzz
418418

419419

420-
def identify_files_needed(text: str) -> list[str]:
420+
def identify_files_needed(text: str, prefix: str | None = None) -> list[str]:
421421
lines = text.splitlines()
422-
return [line[len("*** Update File: ") :] for line in lines if line.startswith("*** Update File: ")] + [
423-
line[len("*** Delete File: ") :] for line in lines if line.startswith("*** Delete File: ")
424-
]
422+
update_files = [line[len("*** Update File: "):] for line in lines if line.startswith("*** Update File: ")]
423+
delete_files = [line[len("*** Delete File: "):] for line in lines if line.startswith("*** Delete File: ")]
424+
all_files = update_files + delete_files
425+
426+
if prefix is None:
427+
return all_files
428+
else:
429+
return [f"{prefix}/{file}" if prefix else file for file in all_files]
425430

426431

427-
def identify_files_added(text: str) -> list[str]:
432+
def identify_files_added(text: str, prefix: str | None = None) -> list[str]:
428433
lines = text.splitlines()
429-
return [line[len("*** Add File: ") :] for line in lines if line.startswith("*** Add File: ")]
430-
434+
added_files = [line[len("*** Add File: "):] for line in lines if line.startswith("*** Add File: ")]
435+
436+
if prefix is None:
437+
return added_files
438+
else:
439+
return [f"{prefix}/{file}" if prefix else file for file in added_files]
431440

432441
# --------------------------------------------------------------------------- #
433442
# File-system helpers
@@ -468,10 +477,11 @@ def process_patch(
468477
write_fn: Callable[[str, str], None],
469478
remove_fn: Callable[[str], None],
470479
inplace: bool = False,
480+
prefix: str | None = None
471481
) -> str:
472482
if not text.startswith("*** Begin Patch"):
473483
raise DiffError("Patch text must start with *** Begin Patch")
474-
paths = identify_files_needed(text)
484+
paths = identify_files_needed(text, prefix)
475485
orig = load_files(paths, open_fn)
476486
patch, _fuzz = text_to_patch(text, orig)
477487
commit = patch_to_commit(patch, orig)
@@ -501,13 +511,13 @@ def remove_file(path: str) -> None:
501511
# --------------------------------------------------------------------------- #
502512
# CLI entry-point
503513
# --------------------------------------------------------------------------- #
504-
def apply_patch_from_text(patch_text: str, inplace: bool = False) -> str:
514+
def apply_patch_from_text(patch_text: str, inplace: bool = False, prefix: str | None = None) -> str:
505515
"""Apply patch text to filesystem, same as main() but with parameter input"""
506516
if not patch_text:
507517
raise DiffError("Patch text cannot be empty")
508518

509519
try:
510-
result = process_patch(patch_text, open_file, write_file, remove_file, inplace)
520+
result = process_patch(patch_text, open_file, write_file, remove_file, inplace, prefix)
511521
return result
512522
except DiffError as exc:
513523
raise exc

rdagent/utils/agent/ret.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@ def get_spec(cls):
9191
return T(".tpl:PythonBatchPatchOut").r()
9292

9393
@classmethod
94-
def extract_output(cls, resp: str) -> str:
94+
def extract_output(cls, resp: str, prefix: str | None = None) -> str:
95+
code_blocks = {}
9596
# Step 1: extract patch by pattern
9697
patch_pattern = re.compile(r"(\*\*\* Begin Patch\s*(.*?)\s*\*\*\* End Patch)", re.DOTALL)
97-
match = patch_pattern.search(resp)
98-
if match:
99-
resp = match.group(1).rstrip()
98+
matches = patch_pattern.findall(resp)
99+
for match in matches:
100+
code_blocks.update(apply_patch_from_text(match, inplace=False, prefix=prefix))
100101

101102
# Step 2: apply the patch, this will modify the file in place
102-
return apply_patch_from_text(resp, inplace=False)
103+
return code_blocks

0 commit comments

Comments
 (0)