Skip to content

Commit 31808e9

Browse files
committed
patch create tool: final changes:
- add a simple and a complex test case - convert output to StringToolOutput, much simpler - wrap everything in try/except, otherwise errors are not displayed Signed-off-by: Tomas Tomecek <[email protected]>
1 parent 66ed826 commit 31808e9

File tree

2 files changed

+137
-132
lines changed

2 files changed

+137
-132
lines changed

beeai/agents/tests/unit/test_tools.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import subprocess
23
from textwrap import dedent
34

45
import pytest
@@ -7,6 +8,7 @@
78

89
from beeai_framework.middleware.trajectory import GlobalTrajectoryMiddleware
910

11+
from tools.wicked_git import GitPatchCreationTool, GitPatchCreationToolInput
1012
from tools.commands import RunShellCommandTool, RunShellCommandToolInput
1113
from tools.specfile import (
1214
AddChangelogEntryTool,
@@ -279,3 +281,70 @@ async def test_str_replace(tmp_path):
279281
"""
280282
)[1:]
281283
)
284+
285+
@pytest.mark.asyncio
286+
async def test_git_patch_creation_tool_nonexistent_repo(tmp_path):
287+
# This test checks the error message for a non-existent repo path
288+
repo_path = tmp_path / "not_a_repo"
289+
patch_file_path = tmp_path / "patch.patch"
290+
tool = GitPatchCreationTool()
291+
output = await tool.run(
292+
input=GitPatchCreationToolInput(
293+
repository_path=str(repo_path),
294+
patch_file_path=str(patch_file_path),
295+
)
296+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
297+
result = output.result
298+
assert "ERROR: Repository path does not exist" in result
299+
300+
@pytest.fixture
301+
def git_repo(tmp_path):
302+
repo_path = tmp_path / "repo"
303+
repo_path.mkdir()
304+
subprocess.run(["git", "init"], cwd=repo_path, check=True)
305+
# Create a file and commit it
306+
file_path = repo_path / "file.txt"
307+
file_path.write_text("Line 1\n")
308+
subprocess.run(["git", "add", "file.txt"], cwd=repo_path, check=True)
309+
subprocess.run(["git", "commit", "-m", "Initial commit"], cwd=repo_path, check=True)
310+
file_path.write_text("Line1\nLine 2\n")
311+
subprocess.run(["git", "add", "file.txt"], cwd=repo_path, check=True)
312+
subprocess.run(["git", "commit", "-m", "Initial commit2"], cwd=repo_path, check=True)
313+
subprocess.run(["git", "branch", "line-2"], cwd=repo_path, check=True)
314+
return repo_path
315+
316+
@pytest.mark.asyncio
317+
async def test_git_patch_creation_tool_success(git_repo, tmp_path):
318+
# Simulate a git-am session by creating a new commit and then using format-patch
319+
# Create a new file and stage it
320+
subprocess.run(["git", "reset", "--hard", "HEAD~1"], cwd=git_repo, check=True)
321+
new_file = git_repo / "file.txt"
322+
new_file.write_text("Line 1\nLine 3\n")
323+
subprocess.run(["git", "add", "file.txt"], cwd=git_repo, check=True)
324+
subprocess.run(["git", "commit", "-m", "Add line 3"], cwd=git_repo, check=True)
325+
326+
patch_file = tmp_path / "patch.patch"
327+
subprocess.run(["git", "format-patch", "-1", "HEAD", "--stdout"], cwd=git_repo, check=True, stdout=patch_file.open("w"))
328+
329+
subprocess.run(["git", "switch", "line-2"], cwd=git_repo, check=True)
330+
331+
# Now apply the patch with git am
332+
# This will fail with a merge conflict, but we don't care about that
333+
subprocess.run(["git", "am", str(patch_file)], cwd=git_repo)
334+
335+
new_file.write_text("Line 1\nLine 2\nLine 3\n")
336+
337+
# Now use the tool to create a patch file from the repo
338+
tool = GitPatchCreationTool()
339+
output_patch = tmp_path / "output.patch"
340+
output = await tool.run(
341+
input=GitPatchCreationToolInput(
342+
repository_path=str(git_repo),
343+
patch_file_path=str(output_patch),
344+
)
345+
).middleware(GlobalTrajectoryMiddleware(pretty=True))
346+
result = output.result
347+
assert "Successfully created a patch file" in result
348+
assert output_patch.exists()
349+
# The patch file should contain the commit message "Add line 3"
350+
assert "Add line 3" in output_patch.read_text()

beeai/agents/tools/wicked_git.py

Lines changed: 68 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,12 @@
55

66
from beeai_framework.context import RunContext
77
from beeai_framework.emitter import Emitter
8-
from beeai_framework.tools import JSONToolOutput, Tool, ToolRunOptions
8+
from beeai_framework.tools import StringToolOutput, Tool, ToolRunOptions
99

1010

1111
class GitPatchCreationToolInput(BaseModel):
12-
repository_path: Path = Field(description="Absolute path to the git repository")
13-
patch_file_path: Path = Field(description="Absolute path where the patch file should be saved")
14-
15-
16-
class GitPatchCreationToolResult(BaseModel):
17-
success: bool = Field(description="Whether the patch creation was successful")
18-
patch_file_path: str = Field(description="Path to the created patch file")
19-
error: str | None = Field(description="Error message if patch creation failed", default=None)
20-
21-
22-
class GitPatchCreationToolOutput(JSONToolOutput[GitPatchCreationToolResult]):
23-
""" Returns a dictionary with success or error and the path to the created patch file. """
12+
repository_path: str = Field(description="Absolute path to the git repository")
13+
patch_file_path: str = Field(description="Absolute path where the patch file should be saved")
2414

2515

2616
async def run_command(cmd: list[str], cwd: Path) -> dict[str, str | int]:
@@ -40,13 +30,12 @@ async def run_command(cmd: list[str], cwd: Path) -> dict[str, str | int]:
4030
"stderr": stderr.decode() if stderr else None,
4131
}
4232

43-
class GitPatchCreationTool(Tool[GitPatchCreationToolInput, ToolRunOptions, GitPatchCreationToolOutput]):
33+
class GitPatchCreationTool(Tool[GitPatchCreationToolInput, ToolRunOptions, StringToolOutput]):
4434
name = "git_patch_create"
4535
description = """
46-
Creates a patch file from the specified git repository with an active git-am session
47-
and after you resolved all merge conflicts. The tool generates a patch file that can be
48-
applied later in the RPM build process. Returns a dictionary with success or error and
49-
the path to the created patch file.
36+
Creates a patch file from the specified git repository with an active git-am session.
37+
The tool expects you resolved all conflicts. It generates a patch file that can be
38+
applied later in the RPM build process.
5039
"""
5140
input_schema = GitPatchCreationToolInput
5241

@@ -58,117 +47,64 @@ def _create_emitter(self) -> Emitter:
5847

5948
async def _run(
6049
self, tool_input: GitPatchCreationToolInput, options: ToolRunOptions | None, context: RunContext
61-
) -> GitPatchCreationToolOutput:
62-
# Ensure the repository path exists and is a git repository
63-
if not tool_input.repository_path.exists():
64-
return GitPatchCreationToolOutput(
65-
result=GitPatchCreationToolResult(
66-
success=False,
67-
patch_file_path="",
68-
patch_content="",
69-
error=f"Repository path does not exist: {tool_input.repository_path}"
70-
)
71-
)
72-
73-
git_dir = tool_input.repository_path / ".git"
74-
if not git_dir.exists():
75-
return GitPatchCreationToolOutput(
76-
result=GitPatchCreationToolResult(
77-
success=False,
78-
patch_file_path="",
79-
patch_content="",
80-
error=f"Not a git repository: {tool_input.repository_path}"
81-
)
82-
)
83-
84-
# list all untracked files in the repository
85-
cmd = ["git", "ls-files", "--others", "--exclude-standard"]
86-
result = await run_command(cmd, cwd=tool_input.repository_path)
87-
if result["exit_code"] != 0:
88-
return GitPatchCreationToolOutput(
89-
result=GitPatchCreationToolResult(
90-
success=False,
91-
patch_file_path="",
92-
patch_content="",
93-
error=f"Git command failed: {result['stderr']}"
94-
)
95-
)
96-
untracked_files = result["stdout"].splitlines()
97-
# list staged as well since that's what the agent usually does after it resolves conflicts
98-
cmd = ["git", "diff", "--name-only", "--cached"]
99-
result = await run_command(cmd, cwd=tool_input.repository_path)
100-
if result["exit_code"] != 0:
101-
return GitPatchCreationToolOutput(
102-
result=GitPatchCreationToolResult(
103-
success=False,
104-
patch_file_path="",
105-
patch_content="",
106-
error=f"Git command failed: {result['stderr']}"
107-
)
108-
)
109-
staged_files = result["stdout"].splitlines()
110-
all_files = untracked_files + staged_files
111-
# make sure there are no *.rej files in the repository
112-
rej_files = [file for file in all_files if file.endswith(".rej")]
113-
if rej_files:
114-
return GitPatchCreationToolOutput(
115-
result=GitPatchCreationToolResult(
116-
success=False,
117-
patch_file_path="",
118-
patch_content="",
119-
error="Merge conflicts detected in the repository: "
120-
f"{tool_input.repository_path}, {rej_files}"
121-
)
122-
)
123-
124-
# git-am leaves the repository in a dirty state, so we need to stage everything
125-
# I considered to inspect the patch and only stage the files that are changed by the patch,
126-
# but the backport process could create new files or change new ones
127-
# so let's go the naive route: git add -A
128-
cmd = ["git", "add", "-A"]
129-
result = await run_command(cmd, cwd=tool_input.repository_path)
130-
if result["exit_code"] != 0:
131-
return GitPatchCreationToolOutput(
132-
result=GitPatchCreationToolResult(
133-
success=False,
134-
patch_file_path="",
135-
patch_content="",
136-
error=f"Git command failed: {result['stderr']}"
137-
)
138-
)
139-
# continue git-am process
140-
cmd = ["git", "am", "--continue"]
141-
result = await run_command(cmd, cwd=tool_input.repository_path)
142-
if result["exit_code"] != 0:
143-
return GitPatchCreationToolOutput(
144-
result=GitPatchCreationToolResult(
145-
success=False,
146-
patch_file_path="",
147-
patch_content="",
148-
error=f"git-am failed: {result['stderr']}, out={result['stdout']}"
149-
)
150-
)
151-
# good, now we should have the patch committed, so let's get the file
152-
cmd = [
153-
"git", "format-patch",
154-
"--output",
155-
str(tool_input.patch_file_path),
156-
"HEAD~1..HEAD"
157-
]
158-
result = await run_command(cmd, cwd=tool_input.repository_path)
159-
if result["exit_code"] != 0:
160-
return GitPatchCreationToolOutput(
161-
result=GitPatchCreationToolResult(
162-
success=False,
163-
patch_file_path="",
164-
patch_content="",
165-
error=f"git-format-patch failed: {result['stderr']}"
166-
)
167-
)
168-
return GitPatchCreationToolOutput(
169-
result=GitPatchCreationToolResult(
170-
success=True,
171-
patch_file_path=str(tool_input.patch_file_path),
172-
error=None
173-
)
174-
)
50+
) -> StringToolOutput:
51+
try:
52+
# Ensure the repository path exists and is a git repository
53+
tool_input_path = Path(tool_input.repository_path)
54+
if not tool_input_path.exists():
55+
return StringToolOutput(result=f"ERROR: Repository path does not exist: {tool_input_path}")
56+
57+
git_dir = tool_input_path / ".git"
58+
if not git_dir.exists():
59+
return StringToolOutput(result=f"ERROR: Not a git repository: {tool_input_path}")
60+
61+
# list all untracked files in the repository
62+
rej_candidates = []
63+
cmd = ["git", "ls-files", "--others", "--exclude-standard"]
64+
result = await run_command(cmd, cwd=tool_input_path)
65+
if result["exit_code"] != 0:
66+
return StringToolOutput(result=f"ERROR: Git command failed: {result['stderr']}")
67+
if result["stdout"]: # none means no untracked files
68+
rej_candidates.extend(result["stdout"].splitlines())
69+
# list staged as well since that's what the agent usually does after it resolves conflicts
70+
cmd = ["git", "diff", "--name-only", "--cached"]
71+
result = await run_command(cmd, cwd=tool_input_path)
72+
if result["exit_code"] != 0:
73+
return StringToolOutput(result=f"ERROR: Git command failed: {result['stderr']}")
74+
if result["stdout"]:
75+
rej_candidates.extend(result["stdout"].splitlines())
76+
if rej_candidates:
77+
# make sure there are no *.rej files in the repository
78+
rej_files = [file for file in rej_candidates if file.endswith(".rej")]
79+
if rej_files:
80+
return StringToolOutput(result=f"ERROR: Merge conflicts detected in the repository: "
81+
f"{tool_input.repository_path}, {rej_files}")
82+
83+
# git-am leaves the repository in a dirty state, so we need to stage everything
84+
# I considered to inspect the patch and only stage the files that are changed by the patch,
85+
# but the backport process could create new files or change new ones
86+
# so let's go the naive route: git add -A
87+
cmd = ["git", "add", "-A"]
88+
result = await run_command(cmd, cwd=tool_input_path)
89+
if result["exit_code"] != 0:
90+
return StringToolOutput(result=f"ERROR: Git command failed: {result['stderr']}")
91+
# continue git-am process
92+
cmd = ["git", "am", "--continue"]
93+
result = await run_command(cmd, cwd=tool_input_path)
94+
if result["exit_code"] != 0:
95+
return StringToolOutput(result=f"ERROR: git-am failed: {result['stderr']},"
96+
f" out={result['stdout']}")
97+
# good, now we should have the patch committed, so let's get the file
98+
cmd = [
99+
"git", "format-patch",
100+
"--output",
101+
tool_input.patch_file_path,
102+
"HEAD~1..HEAD"
103+
]
104+
result = await run_command(cmd, cwd=tool_input_path)
105+
if result["exit_code"] != 0:
106+
return StringToolOutput(result=f"ERROR: git-format-patch failed: {result['stderr']}")
107+
return StringToolOutput(result=f"Successfully created a patch file: {tool_input.patch_file_path}")
108+
except Exception as e:
109+
# we absolutely need to do this otherwise the error won't appear anywhere
110+
return StringToolOutput(result=f"ERROR: {e}")

0 commit comments

Comments
 (0)