1 change: 1 addition & 0 deletions packages/ragbits-agents/CHANGELOG.md
@@ -4,6 +4,7 @@

- Support wrapping downstream agents as tools (#818)
- Add syntax sugar allowing easier Agents definition (#820)
- Support streaming from downstream agents (#812)

## 1.3.0 (2025-09-11)
### Changed
113 changes: 88 additions & 25 deletions packages/ragbits-agents/src/ragbits/agents/_main.py
@@ -8,7 +8,7 @@
from datetime import timedelta
from inspect import iscoroutinefunction
from types import ModuleType, SimpleNamespace
from typing import Any, ClassVar, Generic, TypeVar, cast, overload
from typing import Any, ClassVar, Generic, TypeVar, Union, cast, overload

from pydantic import (
BaseModel,
@@ -50,6 +50,18 @@
_Output = TypeVar("_Output")


@dataclass
class DownstreamAgentResult:
"""
Represents a streamed item from a downstream agent while executing a tool.
"""

agent_id: str
"""ID of the downstream agent."""
item: Union[str, ToolCall, ToolCallResult, "DownstreamAgentResult"]
"""The streamed item from the downstream agent."""


@dataclass
class AgentResult(Generic[PromptOutputT]):
"""
Expand Down Expand Up @@ -157,27 +169,31 @@ class AgentRunContext(BaseModel, Generic[DepsT]):
"""The usage of the agent."""


class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult]):
class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult | DownstreamAgentResult]):
"""
An async iterator that collects all items yielded by LLM.generate_streaming(). This object is returned
by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes,
all items are available under the same names as in the AgentResult class.
"""

def __init__(
self, generator: AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]
self,
generator: AsyncGenerator[
str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage
],
):
self._generator = generator
self.content: str = ""
self.tool_calls: list[ToolCallResult] | None = None
self.downstream: dict[str, list[str | ToolCall | ToolCallResult]] = {}
self.metadata: dict = {}
self.history: ChatFormat
self.usage: Usage = Usage()

def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult]:
def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult | DownstreamAgentResult]:
return self

async def __anext__(self) -> str | ToolCall | ToolCallResult:
async def __anext__(self) -> str | ToolCall | ToolCallResult | DownstreamAgentResult:
try:
item = await self._generator.__anext__()
match item:
@@ -189,6 +205,11 @@ async def __anext__(self) -> str | ToolCall | ToolCallResult:
if self.tool_calls is None:
self.tool_calls = []
self.tool_calls.append(item)
case DownstreamAgentResult():
if item.agent_id not in self.downstream:
self.downstream[item.agent_id] = []
if isinstance(item.item, str | ToolCall | ToolCallResult):
self.downstream[item.agent_id].append(item.item)
case BasePrompt():
item.add_assistant_message(self.content)
self.history = item.chat
@@ -198,6 +219,7 @@ async def __anext__(self) -> str | ToolCall | ToolCallResult:
"content": self.content,
"metadata": self.metadata,
"tool_calls": self.tool_calls,
"downstream": self.downstream or None,
}
raise StopAsyncIteration
case Usage():
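
Once the async-for completes, items relayed from sub-agents are grouped per agent id on the iterator itself. A short usage sketch; the agent construction and prompt are assumed, not part of this diff:

async def collect_downstream(agent) -> None:
    stream = agent.run_streaming("What's the weather in Oslo?")  # hypothetical input
    async for _ in stream:
        pass  # AgentResultStreaming accumulates items while it is iterated
    for agent_id, items in stream.downstream.items():
        print(f"{agent_id}: {len(items)} downstream item(s)")
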
@@ -368,10 +390,12 @@ async def run(
break

for tool_call in response.tool_calls:
result = await self._execute_tool(tool_call=tool_call, tools_mapping=tools_mapping, context=context)
tool_calls.append(result)

prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
async for result in self._execute_tool(
tool_call=tool_call, tools_mapping=tools_mapping, context=context
):
if isinstance(result, ToolCallResult):
tool_calls.append(result)
prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)

turn_count += 1
else:
@@ -403,6 +427,8 @@ def run_streaming(
options: AgentOptions[LLMClientOptionsT] | None = None,
context: AgentRunContext | None = None,
tool_choice: ToolChoice | None = None,
*,
stream_downstream_events: bool = True,
) -> AgentResultStreaming: ...

@overload
@@ -412,6 +438,8 @@ def run_streaming(
options: AgentOptions[LLMClientOptionsT] | None = None,
context: AgentRunContext | None = None,
tool_choice: ToolChoice | None = None,
*,
stream_downstream_events: bool = True,
) -> AgentResultStreaming: ...

def run_streaming(
@@ -420,6 +448,8 @@
options: AgentOptions[LLMClientOptionsT] | None = None,
context: AgentRunContext | None = None,
tool_choice: ToolChoice | None = None,
*,
stream_downstream_events: bool = True,
) -> AgentResultStreaming:
"""
This method returns an `AgentResultStreaming` object that can be asynchronously
@@ -434,6 +464,8 @@ def run_streaming(
- "none": do not call tool
- "required: enforce tool usage (model decides which one)
- Callable: one of provided tools
stream_downstream_events: Whether to stream events from downstream agents when
tools execute other agents. Defaults to True.

Returns:
An `AgentResultStreaming` object for iteration and collection.
@@ -445,7 +477,13 @@
AgentInvalidPromptInputError: If the prompt/input combination is invalid.
AgentMaxTurnsExceededError: If the maximum number of turns is exceeded.
"""
generator = self._stream_internal(input, options, context, tool_choice)
generator = self._stream_internal(
input=input,
options=options,
context=context,
tool_choice=tool_choice,
stream_downstream_events=stream_downstream_events,
)
return AgentResultStreaming(generator)
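
The flag is keyword-only, so opting out is explicit at the call site. A sketch under the same assumptions as above; when relaying is disabled the sub-agent still runs to completion, its events are simply consumed without being yielded:

async def quiet_run(agent) -> str:
    stream = agent.run_streaming(
        "Summarize the quarterly report",  # hypothetical input
        stream_downstream_events=False,
    )
    async for _ in stream:
        pass
    return stream.content  # stream.downstream stays empty in this mode
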

async def _stream_internal(
@@ -454,7 +492,8 @@ async def _stream_internal(
options: AgentOptions[LLMClientOptionsT] | None = None,
context: AgentRunContext | None = None,
tool_choice: ToolChoice | None = None,
) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]:
stream_downstream_events: bool = True,
) -> AsyncGenerator[str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage]:
if context is None:
context = AgentRunContext()

@@ -467,24 +506,34 @@ async def _stream_internal(
turn_count = 0
max_turns = merged_options.max_turns
max_turns = 10 if max_turns is NOT_GIVEN else max_turns

with trace(input=input, options=merged_options) as outputs:
while not max_turns or turn_count < max_turns:
returned_tool_call = False
self._check_token_limits(merged_options, context.usage, prompt_with_history, self.llm)

streaming_result = self.llm.generate_streaming(
prompt=prompt_with_history,
tools=[tool.to_function_schema() for tool in tools_mapping.values()],
tool_choice=tool_choice if tool_choice and turn_count == 0 else None,
options=self._get_llm_options(llm_options, merged_options, context.usage),
)

async for chunk in streaming_result:
yield chunk

if isinstance(chunk, ToolCall):
result = await self._execute_tool(tool_call=chunk, tools_mapping=tools_mapping, context=context)
yield result
prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
returned_tool_call = True
async for result in self._execute_tool(
tool_call=chunk,
tools_mapping=tools_mapping,
context=context,
stream_downstream_events=stream_downstream_events,
):
yield result
if isinstance(result, ToolCallResult):
prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__)
returned_tool_call = True

turn_count += 1
if streaming_result.usage:
context.usage += streaming_result.usage
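
The structural change here is that _execute_tool is now an async generator: it yields DownstreamAgentResult items while a wrapped agent streams, then finishes with the terminal ToolCallResult, which is what both run() and _stream_internal pick out with isinstance(). A standalone sketch of that ordering contract, with illustrative names rather than the library's API:

from collections.abc import AsyncGenerator, AsyncIterator

async def relay_then_finish(
    sub_stream: AsyncIterator[str], stream_downstream_events: bool
) -> AsyncGenerator[tuple[str, object], None]:
    # Optionally relay events as they arrive; always yield exactly one
    # terminal ("result", ...) item last, mirroring the ToolCallResult.
    parts: list[str] = []
    async for chunk in sub_stream:
        parts.append(chunk)  # the tool's result is built either way
        if stream_downstream_events:
            yield ("downstream", chunk)
    yield ("result", "".join(parts))
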
@@ -615,10 +664,10 @@ async def _execute_tool(
tool_call: ToolCall,
tools_mapping: dict[str, Tool],
context: AgentRunContext | None = None,
) -> ToolCallResult:
stream_downstream_events: bool = True,
) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult, None]:
if tool_call.type != "function":
raise AgentToolNotSupportedError(tool_call.type)

if tool_call.name not in tools_mapping:
raise AgentToolNotAvailableError(tool_call.name)

@@ -630,15 +679,29 @@
if tool.context_var_name:
call_args[tool.context_var_name] = context

tool_output = (
await tool.on_tool_call(**call_args)
if iscoroutinefunction(tool.on_tool_call)
else tool.on_tool_call(**call_args)
)
call_args["_stream_downstream_events"] = stream_downstream_events

try:
tool_output = (
await tool.on_tool_call(**call_args)
if iscoroutinefunction(tool.on_tool_call)
else tool.on_tool_call(**call_args)
)
except TypeError as e:
if "_stream_downstream_events" in str(e):
call_args.pop("_stream_downstream_events", None)
tool_output = (
await tool.on_tool_call(**call_args)
if iscoroutinefunction(tool.on_tool_call)
else tool.on_tool_call(**call_args)
)
else:
raise

if isinstance(tool_output, AgentResultStreaming):
async for _ in tool_output:
pass
async for downstream_item in tool_output:
if stream_downstream_events:
yield DownstreamAgentResult(agent_id=self.id, item=downstream_item)

Review comment (Member): for agent_id to be useful, agents should be registered somewhere (AgentRunContext?)

tool_output = {
"content": tool_output.content,
@@ -659,7 +722,7 @@
}
raise AgentToolExecutionError(tool_call.name, e) from e

return ToolCallResult(
yield ToolCallResult(
id=tool_call.id,
name=tool_call.name,
arguments=tool_call.arguments,
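
For backward compatibility, every tool is first offered the private _stream_downstream_events kwarg, and plain functions that reject it are retried without it. A simplified sketch of that retry; the real code also awaits coroutine tools, and matching on the error string is the mechanism this PR itself uses:

def call_with_optional_flag(fn, **call_args):
    call_args["_stream_downstream_events"] = True
    try:
        return fn(**call_args)
    except TypeError as e:
        # Only swallow the unexpected-keyword case; a TypeError raised
        # inside the tool itself propagates unchanged.
        if "_stream_downstream_events" not in str(e):
            raise
        call_args.pop("_stream_downstream_events", None)
        return fn(**call_args)
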
5 changes: 4 additions & 1 deletion packages/ragbits-agents/src/ragbits/agents/tool.py
@@ -151,12 +151,15 @@ def from_agent(
parameters = {"type": "object", "properties": properties, "required": required}

def _on_tool_call(**kwargs: dict) -> "AgentResultStreaming":
_stream_flag = kwargs.pop("_stream_downstream_events", True)
Review comment (Member): it could be stored in AgentRunContext, I think.
stream_downstream_events = _stream_flag if isinstance(_stream_flag, bool) else True

if input_model_cls and issubclass(input_model_cls, BaseModel):
model_input = input_model_cls(**kwargs)
else:
model_input = kwargs.get("input")

return agent.run_streaming(model_input)
return agent.run_streaming(model_input, stream_downstream_events=stream_downstream_events)

return cls(
name=variable_name,
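
Putting the two files together, a hypothetical end-to-end sketch: wrap one agent as a tool of another and watch its events surface in the parent's stream. The Agent constructor arguments and the Tool.from_agent call shape are assumptions for illustration, not confirmed signatures:

import asyncio
from ragbits.agents._main import Agent, DownstreamAgentResult  # path assumed
from ragbits.agents.tool import Tool

async def main() -> None:
    llm = ...  # any ragbits LLM instance; construction omitted here
    weather = Agent(llm=llm, prompt="You answer weather questions.")
    planner = Agent(
        llm=llm,
        prompt="You plan trips and may consult other agents.",
        tools=[Tool.from_agent(weather)],
    )
    stream = planner.run_streaming("Should I cycle to work tomorrow?")
    async for item in stream:
        if isinstance(item, DownstreamAgentResult):
            print(f"(from downstream {item.agent_id})", item.item)

asyncio.run(main())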