Skip to content

Commit ecb3a89

Browse files
committed
refactor tools/ to have a single run_command
Signed-off-by: Tomas Tomecek <[email protected]>
1 parent 67752b5 commit ecb3a89

File tree

3 files changed

+52
-31
lines changed

3 files changed

+52
-31
lines changed

beeai/agents/tools/commands.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from beeai_framework.emitter import Emitter
88
from beeai_framework.tools import JSONToolOutput, Tool, ToolRunOptions
99

10+
from tools.utils import run_command
11+
1012

1113
class RunShellCommandToolInput(BaseModel):
1214
command: str = Field(description="Command to run")
@@ -39,18 +41,6 @@ def _create_emitter(self) -> Emitter:
3941
async def _run(
4042
self, tool_input: RunShellCommandToolInput, options: ToolRunOptions | None, context: RunContext
4143
) -> RunShellCommandToolOutput:
42-
proc = await asyncio.create_subprocess_shell(
43-
tool_input.command,
44-
stdout=asyncio.subprocess.PIPE,
45-
stderr=asyncio.subprocess.PIPE,
46-
)
47-
48-
stdout, stderr = await proc.communicate()
49-
50-
result = {
51-
"exit_code": proc.returncode,
52-
"stdout": stdout.decode() if stdout else None,
53-
"stderr": stderr.decode() if stderr else None,
54-
}
44+
result = await run_command(tool_input.command, subprocess_function=asyncio.create_subprocess_shell)
5545

5646
return RunShellCommandToolOutput(RunShellCommandToolResult.model_validate(result))

beeai/agents/tools/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import asyncio
2+
from pathlib import Path
3+
from typing import Callable, Optional, Union
4+
5+
6+
async def run_command(
7+
cmd: Union[list[str], str],
8+
cwd: Optional[Path] = None,
9+
subprocess_function: Callable = asyncio.create_subprocess_exec
10+
) -> dict[str, str | int]:
11+
"""
12+
Run a shell command asynchronously and capture its output.
13+
14+
Args:
15+
cmd: The command to run. Can be a list of arguments (for exec) or a string (for shell).
16+
cwd: Optional working directory to run the command in.
17+
subprocess_function: The asyncio subprocess function to use (e.g., create_subprocess_exec or create_subprocess_shell).
18+
19+
Returns:
20+
A dictionary with:
21+
- "exit_code": The process exit code (int)
22+
- "stdout": The standard output as a string, or None
23+
- "stderr": The standard error as a string, or None
24+
"""
25+
if subprocess_function is asyncio.create_subprocess_exec:
26+
proc = await subprocess_function(
27+
cmd[0],
28+
*cmd[1:],
29+
stdout=asyncio.subprocess.PIPE,
30+
stderr=asyncio.subprocess.PIPE,
31+
cwd=cwd,
32+
)
33+
else:
34+
proc = await subprocess_function(
35+
cmd,
36+
stdout=asyncio.subprocess.PIPE,
37+
stderr=asyncio.subprocess.PIPE,
38+
cwd=cwd,
39+
)
40+
41+
stdout, stderr = await proc.communicate()
42+
43+
return {
44+
"exit_code": proc.returncode,
45+
"stdout": stdout.decode() if stdout else None,
46+
"stderr": stderr.decode() if stderr else None,
47+
}

beeai/agents/tools/wicked_git.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import asyncio
21
from pathlib import Path
32

43
from pydantic import BaseModel, Field
54

5+
from tools.utils import run_command
6+
67
from beeai_framework.context import RunContext
78
from beeai_framework.emitter import Emitter
89
from beeai_framework.tools import StringToolOutput, Tool, ToolRunOptions
@@ -13,23 +14,6 @@ class GitPatchCreationToolInput(BaseModel):
1314
patch_file_path: str = Field(description="Absolute path where the patch file should be saved")
1415

1516

16-
async def run_command(cmd: list[str], cwd: Path) -> dict[str, str | int]:
17-
proc = await asyncio.create_subprocess_exec(
18-
cmd[0],
19-
*cmd[1:],
20-
stdout=asyncio.subprocess.PIPE,
21-
stderr=asyncio.subprocess.PIPE,
22-
cwd=cwd,
23-
)
24-
25-
stdout, stderr = await proc.communicate()
26-
27-
return {
28-
"exit_code": proc.returncode,
29-
"stdout": stdout.decode() if stdout else None,
30-
"stderr": stderr.decode() if stderr else None,
31-
}
32-
3317
class GitPatchCreationTool(Tool[GitPatchCreationToolInput, ToolRunOptions, StringToolOutput]):
3418
name = "git_patch_create"
3519
description = """

0 commit comments

Comments
 (0)