Commit 661e4f7

feat: support wrapping downstream agents as tools (#819)
1 parent ee680a4 commit 661e4f7

File tree

3 files changed (+117, -3 lines)


packages/ragbits-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 ## Unreleased
 
+- Support wrapping downstream agents as tools (#818)
+
 ## 1.3.0 (2025-09-11)
 ### Changed

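A minimal sketch of what the new changelog entry enables from user code: a downstream agent can be passed directly in another agent's tools list. The LLM class, model names, and prompts below are illustrative assumptions, not part of this commit.

    from ragbits.agents import Agent
    from ragbits.core.llms import LiteLLM  # assumed import path

    # Downstream agent that will be exposed as a tool (hypothetical setup).
    research_agent = Agent(
        llm=LiteLLM("gpt-4o-mini"),
        name="research agent",
        description="Looks up background information.",
        prompt="Answer the research question concisely.",
    )

    # The orchestrator now accepts the agent itself in `tools`.
    orchestrator = Agent(llm=LiteLLM("gpt-4o"), tools=[research_agent])
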
packages/ragbits-agents/src/ragbits/agents/_main.py

Lines changed: 44 additions & 2 deletions
@@ -219,6 +219,8 @@ class Agent(
     def __init__(
         self,
         llm: LLM[LLMClientOptionsT],
+        name: str | None = None,
+        description: str | None = None,
         prompt: str | type[Prompt[PromptInputT, PromptOutputT]] | Prompt[PromptInputT, PromptOutputT] | None = None,
         *,
         history: ChatFormat | None = None,
@@ -232,6 +234,8 @@ def __init__(
 
         Args:
             llm: The LLM to run the agent.
+            name: Optional name of the agent. Used to identify the agent instance.
+            description: Optional description of the agent.
             prompt: The prompt for the agent. Can be:
                 - str: A string prompt that will be used as system message when combined with string input,
                   or as the user message when no input is provided during run().
@@ -248,7 +252,17 @@ def __init__(
         self.id = uuid.uuid4().hex[:8]
         self.llm = llm
         self.prompt = prompt
-        self.tools = [Tool.from_callable(tool) for tool in tools or []]
+        self.name = name
+        self.description = description
+        self.tools = []
+        for tool in tools or []:
+            if isinstance(tool, tuple):
+                agent, kwargs = tool
+                self.tools.append(Tool.from_agent(agent, **kwargs))
+            elif isinstance(tool, Agent):
+                self.tools.append(Tool.from_agent(tool))
+            else:
+                self.tools.append(Tool.from_callable(tool))
         self.mcp_servers = mcp_servers or []
         self.history = history or []
         self.keep_history = keep_history
@@ -555,7 +569,10 @@ async def _get_all_tools(self) -> dict[str, Tool]:
         return tools_mapping
 
     async def _execute_tool(
-        self, tool_call: ToolCall, tools_mapping: dict[str, Tool], context: AgentRunContext | None = None
+        self,
+        tool_call: ToolCall,
+        tools_mapping: dict[str, Tool],
+        context: AgentRunContext | None = None,
     ) -> ToolCallResult:
         if tool_call.type != "function":
             raise AgentToolNotSupportedError(tool_call.type)
@@ -577,10 +594,22 @@ async def _execute_tool(
                 else tool.on_tool_call(**call_args)
             )
 
+            if isinstance(tool_output, AgentResultStreaming):
+                async for _ in tool_output:
+                    pass
+
+                tool_output = {
+                    "content": tool_output.content,
+                    "metadata": tool_output.metadata,
+                    "tool_calls": tool_output.tool_calls,
+                    "usage": tool_output.usage,
+                }
+
             outputs.result = {
                 "tool_output": tool_output,
                 "tool_call_id": tool_call.id,
             }
+
         except Exception as e:
             outputs.result = {
                 "error": str(e),
@@ -758,3 +787,16 @@ def from_pydantic_ai(cls, pydantic_ai_agent: "PydanticAIAgent") -> Self:
             tools=[tool.function for _, tool in pydantic_ai_agent._function_tools.items()],
             mcp_servers=cast(list[MCPServer], mcp_servers),
         )
+
+    def to_tool(self, name: str | None = None, description: str | None = None) -> Tool:
+        """
+        Convert the agent into a Tool instance.
+
+        Args:
+            name: Optional override for the tool name.
+            description: Optional override for the tool description.
+
+        Returns:
+            Tool instance representing the agent.
+        """
+        return Tool.from_agent(self, name=name or self.name, description=description or self.description)

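The constructor changes above accept three shapes in tools: a plain callable, an Agent instance, or an (agent, kwargs) tuple whose kwargs are forwarded to Tool.from_agent; the new to_tool() method builds the same wrapper explicitly. A hedged sketch of the tuple form and to_tool(), with the LLM setup and prompts as illustrative assumptions:

    from ragbits.agents import Agent
    from ragbits.core.llms import LiteLLM  # assumed import path

    summarizer = Agent(
        llm=LiteLLM("gpt-4o-mini"),
        name="summarizer",
        description="Summarizes long documents.",
        prompt="Summarize the provided text.",
    )

    # (agent, kwargs) tuple: kwargs go to Tool.from_agent, e.g. to override name/description.
    orchestrator = Agent(
        llm=LiteLLM("gpt-4o"),
        tools=[(summarizer, {"name": "summarize", "description": "Condense a document."})],
    )

    # Equivalent explicit wrapper via the new helper.
    summarize_tool = summarizer.to_tool(name="summarize")
    print(summarize_tool.name, summarize_tool.parameters)
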
packages/ragbits-agents/src/ragbits/agents/tool.py

Lines changed: 71 additions & 1 deletion
@@ -1,13 +1,19 @@
 from collections.abc import Callable
 from contextlib import suppress
 from dataclasses import dataclass
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
+from pydantic import BaseModel
 from typing_extensions import Self
 
+from ragbits.core.llms.base import LLMClientOptionsT
+from ragbits.core.prompt.prompt import PromptInputT, PromptOutputT
 from ragbits.core.utils.decorators import requires_dependencies
 from ragbits.core.utils.function_schema import convert_function_to_function_schema, get_context_variable_name
 
+if TYPE_CHECKING:
+    from ragbits.agents import Agent, AgentResultStreaming
+
 with suppress(ImportError):
     from pydantic_ai import Tool as PydanticAITool
 
@@ -96,5 +102,69 @@ def to_pydantic_ai(self) -> "PydanticAITool":
             description=self.description,
         )
 
+    @classmethod
+    def from_agent(
+        cls,
+        agent: "Agent[LLMClientOptionsT, PromptInputT, PromptOutputT]",
+        name: str | None = None,
+        description: str | None = None,
+    ) -> "Tool":
+        """
+        Wraps a downstream agent as a single tool. The tool parameters are inferred from
+        the downstream agent's prompt input.
+
+        Args:
+            agent: The downstream agent to wrap as a tool.
+            name: Optional override for the tool name.
+            description: Optional override for the tool description.
+
+        Returns:
+            Tool instance representing the agent.
+        """
+        display_name = name or agent.name or "agent"
+        variable_name = display_name.replace(" ", "_").lower()
+        description = description or agent.description
+
+        input_model_cls = getattr(agent.prompt, "input_type", None)
+        if input_model_cls and issubclass(input_model_cls, BaseModel):
+            fields = input_model_cls.model_fields
+            properties = {}
+            required = list(fields.keys())
+
+            for field_name in fields:
+                param_desc = None
+                for t in getattr(agent, "tools", []):
+                    t_params = getattr(t, "parameters", {}).get("properties", {})
+                    if field_name in t_params:
+                        param_desc = t_params[field_name].get("description")
+                        break
+
+                properties[field_name] = {
+                    "type": "string",
+                    "title": field_name.capitalize(),
+                    "description": param_desc,
+                }
+        else:
+            properties = {"input": {"type": "string", "description": "Input for the downstream agent"}}
+            required = ["input"]
+
+        parameters = {"type": "object", "properties": properties, "required": required}
+
+        def _on_tool_call(**kwargs: dict) -> "AgentResultStreaming":
+            if input_model_cls and issubclass(input_model_cls, BaseModel):
+                model_input = input_model_cls(**kwargs)
+            else:
+                model_input = kwargs.get("input")
+
+            return agent.run_streaming(model_input)
+
+        return cls(
+            name=variable_name,
+            description=description,
+            parameters=parameters,
+            on_tool_call=_on_tool_call,
+            context_var_name=get_context_variable_name(agent.run),
+        )
+
 
 ToolChoice = Literal["auto", "none", "required"] | Callable
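Tool.from_agent infers the tool's parameter schema from the downstream agent's prompt: if the prompt declares a Pydantic input model, every field becomes a required string property (descriptions are borrowed from the sub-agent's own tool schemas when field names match); otherwise the tool falls back to a single "input" string parameter. When invoked, the wrapper drives the agent via run_streaming, and _execute_tool drains the stream and collapses the result into a plain dict. A rough sketch of the Pydantic-model path; the prompt class, its fields, and the import paths outside this diff are assumptions:

    from pydantic import BaseModel

    from ragbits.agents import Agent
    from ragbits.agents.tool import Tool
    from ragbits.core.llms import LiteLLM  # assumed import path
    from ragbits.core.prompt import Prompt  # assumed import path


    class SearchInput(BaseModel):  # hypothetical prompt input model
        query: str
        language: str


    class SearchPrompt(Prompt[SearchInput, str]):  # hypothetical prompt
        system_prompt = "You are a search assistant."
        user_prompt = "Search for {{ query }} in {{ language }}."


    search_agent = Agent(llm=LiteLLM("gpt-4o-mini"), name="search agent", prompt=SearchPrompt)
    tool = Tool.from_agent(search_agent)

    # Per the implementation above: the name becomes "search_agent" and both
    # prompt-input fields are exposed as required string parameters.
    print(tool.name)                    # search_agent
    print(tool.parameters["required"])  # ['query', 'language']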