From 0359dd29904f12534c0065c0878d99cdc0424826 Mon Sep 17 00:00:00 2001 From: alicja Date: Thu, 18 Sep 2025 08:32:50 +0200 Subject: [PATCH 1/4] allow streaming from downstream agents --- packages/ragbits-agents/CHANGELOG.md | 1 + .../src/ragbits/agents/_main.py | 135 +++++++++++------- .../ragbits-agents/src/ragbits/agents/tool.py | 5 +- 3 files changed, 91 insertions(+), 50 deletions(-) diff --git a/packages/ragbits-agents/CHANGELOG.md b/packages/ragbits-agents/CHANGELOG.md index 2a722f747..0b16d8949 100644 --- a/packages/ragbits-agents/CHANGELOG.md +++ b/packages/ragbits-agents/CHANGELOG.md @@ -4,6 +4,7 @@ - Support wrapping downstream agents as tools (#818) - Add syntax sugar allowing easier Agents definition (#820) +- Support streaming-from-downstream-agents (#812) ## 1.3.0 (2025-09-11) ### Changed diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index 6b15a00e6..80e034d83 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -8,7 +8,7 @@ from datetime import timedelta from inspect import iscoroutinefunction from types import ModuleType, SimpleNamespace -from typing import Any, ClassVar, Generic, TypeVar, cast, overload +from typing import Any, ClassVar, Generic, TypeVar, Union, cast, overload from pydantic import ( BaseModel, @@ -50,6 +50,18 @@ _Output = TypeVar("_Output") +@dataclass +class DownstreamAgentResult: + """ + Represents a streamed item from a downstream agent while executing a tool. 
+ """ + + agent_id: str + """ID of the downstream agent.""" + item: Union[str, ToolCall, ToolCallResult, "DownstreamAgentResult"] + """The streamed item from the downstream agent.""" + + @dataclass class AgentResult(Generic[PromptOutputT]): """ @@ -157,7 +169,7 @@ class AgentRunContext(BaseModel, Generic[DepsT]): """The usage of the agent.""" -class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult]): +class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult | DownstreamAgentResult]): """ An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes, @@ -165,19 +177,23 @@ class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult]): """ def __init__( - self, generator: AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage] + self, + generator: AsyncGenerator[ + str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage + ], ): self._generator = generator self.content: str = "" self.tool_calls: list[ToolCallResult] | None = None + self.downstream: dict[str, list[str | ToolCall | ToolCallResult]] = {} self.metadata: dict = {} self.history: ChatFormat self.usage: Usage = Usage() - def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult]: + def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult | DownstreamAgentResult]: return self - async def __anext__(self) -> str | ToolCall | ToolCallResult: + async def __anext__(self) -> str | ToolCall | ToolCallResult | DownstreamAgentResult: try: item = await self._generator.__anext__() match item: @@ -189,6 +205,11 @@ async def __anext__(self) -> str | ToolCall | ToolCallResult: if self.tool_calls is None: self.tool_calls = [] self.tool_calls.append(item) + case DownstreamAgentResult(): + if item.agent_id not in self.downstream: + 
self.downstream[item.agent_id] = [] + if isinstance(item.item, (str, ToolCall, ToolCallResult)): + self.downstream[item.agent_id].append(item.item) case BasePrompt(): item.add_assistant_message(self.content) self.history = item.chat @@ -198,6 +219,7 @@ async def __anext__(self) -> str | ToolCall | ToolCallResult: "content": self.content, "metadata": self.metadata, "tool_calls": self.tool_calls, + "downstream": self.downstream or None, } raise StopAsyncIteration case Usage(): @@ -368,10 +390,12 @@ async def run( break for tool_call in response.tool_calls: - result = await self._execute_tool(tool_call=tool_call, tools_mapping=tools_mapping, context=context) - tool_calls.append(result) - - prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) + async for result in self._execute_tool( + tool_call=tool_call, tools_mapping=tools_mapping, context=context + ): + if isinstance(result, ToolCallResult): + tool_calls.append(result) + prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) turn_count += 1 else: @@ -403,6 +427,8 @@ def run_streaming( options: AgentOptions[LLMClientOptionsT] | None = None, context: AgentRunContext | None = None, tool_choice: ToolChoice | None = None, + *, + stream_downstream_events: bool = True, ) -> AgentResultStreaming: ... @overload @@ -412,6 +438,8 @@ def run_streaming( options: AgentOptions[LLMClientOptionsT] | None = None, context: AgentRunContext | None = None, tool_choice: ToolChoice | None = None, + *, + stream_downstream_events: bool = True, ) -> AgentResultStreaming: ... def run_streaming( @@ -420,32 +448,16 @@ def run_streaming( options: AgentOptions[LLMClientOptionsT] | None = None, context: AgentRunContext | None = None, tool_choice: ToolChoice | None = None, + *, + stream_downstream_events: bool = True, ) -> AgentResultStreaming: - """ - This method returns an `AgentResultStreaming` object that can be asynchronously - iterated over. 
After the loop completes, all items are available under the same names as in AgentResult class. - - Args: - input: The input for the agent run. - options: The options for the agent run. - context: The context for the agent run. - tool_choice: Parameter that allows to control what tool is used at first call. Can be one of: - - "auto": let model decide if tool call is needed - - "none": do not call tool - - "required: enforce tool usage (model decides which one) - - Callable: one of provided tools - - Returns: - A `StreamingResult` object for iteration and collection. - - Raises: - AgentToolDuplicateError: If the tool names are duplicated. - AgentToolNotSupportedError: If the selected tool type is not supported. - AgentToolNotAvailableError: If the selected tool is not available. - AgentInvalidPromptInputError: If the prompt/input combination is invalid. - AgentMaxTurnsExceededError: If the maximum number of turns is exceeded. - """ - generator = self._stream_internal(input, options, context, tool_choice) + generator = self._stream_internal( + input=input, + options=options, + context=context, + tool_choice=tool_choice, + stream_downstream_events=stream_downstream_events, + ) return AgentResultStreaming(generator) async def _stream_internal( @@ -454,7 +466,8 @@ async def _stream_internal( options: AgentOptions[LLMClientOptionsT] | None = None, context: AgentRunContext | None = None, tool_choice: ToolChoice | None = None, - ) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]: + stream_downstream_events: bool = True, + ) -> AsyncGenerator[str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage]: if context is None: context = AgentRunContext() @@ -467,24 +480,34 @@ async def _stream_internal( turn_count = 0 max_turns = merged_options.max_turns max_turns = 10 if max_turns is NOT_GIVEN else max_turns + with trace(input=input, options=merged_options) as outputs: while not max_turns or turn_count 
< max_turns: returned_tool_call = False self._check_token_limits(merged_options, context.usage, prompt_with_history, self.llm) + streaming_result = self.llm.generate_streaming( prompt=prompt_with_history, tools=[tool.to_function_schema() for tool in tools_mapping.values()], tool_choice=tool_choice if tool_choice and turn_count == 0 else None, options=self._get_llm_options(llm_options, merged_options, context.usage), ) + async for chunk in streaming_result: yield chunk if isinstance(chunk, ToolCall): - result = await self._execute_tool(tool_call=chunk, tools_mapping=tools_mapping, context=context) - yield result - prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) - returned_tool_call = True + async for result in self._execute_tool( + tool_call=chunk, + tools_mapping=tools_mapping, + context=context, + stream_downstream_events=stream_downstream_events, + ): + yield result + if isinstance(result, ToolCallResult): + prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) + returned_tool_call = True + turn_count += 1 if streaming_result.usage: context.usage += streaming_result.usage @@ -615,10 +638,10 @@ async def _execute_tool( tool_call: ToolCall, tools_mapping: dict[str, Tool], context: AgentRunContext | None = None, - ) -> ToolCallResult: + stream_downstream_events: bool = True, + ) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult, None]: if tool_call.type != "function": raise AgentToolNotSupportedError(tool_call.type) - if tool_call.name not in tools_mapping: raise AgentToolNotAvailableError(tool_call.name) @@ -630,15 +653,29 @@ async def _execute_tool( if tool.context_var_name: call_args[tool.context_var_name] = context - tool_output = ( - await tool.on_tool_call(**call_args) - if iscoroutinefunction(tool.on_tool_call) - else tool.on_tool_call(**call_args) - ) + call_args["_stream_downstream_events"] = stream_downstream_events + + try: + tool_output = ( + await tool.on_tool_call(**call_args) + if 
iscoroutinefunction(tool.on_tool_call) + else tool.on_tool_call(**call_args) + ) + except TypeError as e: + if "_stream_downstream_events" in str(e): + call_args.pop("_stream_downstream_events", None) + tool_output = ( + await tool.on_tool_call(**call_args) + if iscoroutinefunction(tool.on_tool_call) + else tool.on_tool_call(**call_args) + ) + else: + raise if isinstance(tool_output, AgentResultStreaming): - async for _ in tool_output: - pass + async for downstream_item in tool_output: + if stream_downstream_events: + yield DownstreamAgentResult(agent_id=self.id, item=downstream_item) tool_output = { "content": tool_output.content, @@ -659,7 +696,7 @@ async def _execute_tool( } raise AgentToolExecutionError(tool_call.name, e) from e - return ToolCallResult( + yield ToolCallResult( id=tool_call.id, name=tool_call.name, arguments=tool_call.arguments, diff --git a/packages/ragbits-agents/src/ragbits/agents/tool.py b/packages/ragbits-agents/src/ragbits/agents/tool.py index 3da403782..e75f5f117 100644 --- a/packages/ragbits-agents/src/ragbits/agents/tool.py +++ b/packages/ragbits-agents/src/ragbits/agents/tool.py @@ -151,12 +151,15 @@ def from_agent( parameters = {"type": "object", "properties": properties, "required": required} def _on_tool_call(**kwargs: dict) -> "AgentResultStreaming": + _stream_flag = kwargs.pop("_stream_downstream_events", True) + stream_downstream_events = _stream_flag if isinstance(_stream_flag, bool) else True + if input_model_cls and issubclass(input_model_cls, BaseModel): model_input = input_model_cls(**kwargs) else: model_input = kwargs.get("input") - return agent.run_streaming(model_input) + return agent.run_streaming(model_input, stream_downstream_events=stream_downstream_events) return cls( name=variable_name, From 09a9ee2041d872d79f758fd3271b9ca828589528 Mon Sep 17 00:00:00 2001 From: alicja Date: Thu, 18 Sep 2025 08:38:02 +0200 Subject: [PATCH 2/4] fix docstring --- .../src/ragbits/agents/_main.py | 23 +++++++++++++++++++ 1 file changed, 
23 insertions(+) diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index 80e034d83..e83d72a58 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -451,6 +451,29 @@ def run_streaming( *, stream_downstream_events: bool = True, ) -> AgentResultStreaming: + """ + This method returns an `AgentResultStreaming` object that can be asynchronously + iterated over. After the loop completes, all items are available under the same names as in AgentResult class. + Args: + input: The input for the agent run. + options: The options for the agent run. + context: The context for the agent run. + tool_choice: Parameter that allows to control what tool is used at first call. Can be one of: + - "auto": let model decide if tool call is needed + - "none": do not call tool + - "required": enforce tool usage (model decides which one) + - Callable: one of provided tools + stream_downstream_events: Whether to stream events from downstream agents when + tools execute other agents. Defaults to True. + Returns: + A `StreamingResult` object for iteration and collection. + Raises: + AgentToolDuplicateError: If the tool names are duplicated. + AgentToolNotSupportedError: If the selected tool type is not supported. + AgentToolNotAvailableError: If the selected tool is not available. + AgentInvalidPromptInputError: If the prompt/input combination is invalid. + AgentMaxTurnsExceededError: If the maximum number of turns is exceeded. 
+ """ generator = self._stream_internal( input=input, options=options, From 54ac51e690ee4d17cfc6ed08e5b2ba2d2527b34f Mon Sep 17 00:00:00 2001 From: alicja Date: Thu, 18 Sep 2025 08:41:32 +0200 Subject: [PATCH 3/4] fix ruff --- packages/ragbits-agents/src/ragbits/agents/_main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index e83d72a58..5e9ae62bb 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -454,6 +454,7 @@ def run_streaming( """ This method returns an `AgentResultStreaming` object that can be asynchronously iterated over. After the loop completes, all items are available under the same names as in AgentResult class. + Args: input: The input for the agent run. options: The options for the agent run. @@ -465,8 +466,10 @@ def run_streaming( - Callable: one of provided tools stream_downstream_events: Whether to stream events from downstream agents when tools execute other agents. Defaults to True. + Returns: A `StreamingResult` object for iteration and collection. + Raises: AgentToolDuplicateError: If the tool names are duplicated. AgentToolNotSupportedError: If the selected tool type is not supported. 
From f4c5c23c16f602638ee2adcb3f7c8e25eaca6d61 Mon Sep 17 00:00:00 2001 From: alicja Date: Thu, 18 Sep 2025 08:45:27 +0200 Subject: [PATCH 4/4] fix ruff --- packages/ragbits-agents/src/ragbits/agents/_main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index 5e9ae62bb..60532daab 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -208,7 +208,7 @@ async def __anext__(self) -> str | ToolCall | ToolCallResult | DownstreamAgentRe case DownstreamAgentResult(): if item.agent_id not in self.downstream: self.downstream[item.agent_id] = [] - if isinstance(item.item, (str, ToolCall, ToolCallResult)): + if isinstance(item.item, str | ToolCall | ToolCallResult): self.downstream[item.agent_id].append(item.item) case BasePrompt(): item.add_assistant_message(self.content)