diff --git a/docs/api_reference/agents/index.md b/docs/api_reference/agents/index.md index ec2a0d655..a1d59bd0f 100644 --- a/docs/api_reference/agents/index.md +++ b/docs/api_reference/agents/index.md @@ -11,3 +11,5 @@ ::: ragbits.agents.a2a.server.create_agent_server ::: ragbits.agents.post_processors.base + +::: ragbits.agents.AgentRunContext diff --git a/docs/how-to/agents/stream_downstream_agents.md b/docs/how-to/agents/stream_downstream_agents.md new file mode 100644 index 000000000..1cbaccb62 --- /dev/null +++ b/docs/how-to/agents/stream_downstream_agents.md @@ -0,0 +1,48 @@ +# How-To: Stream downstream agents with Ragbits + +Ragbits [Agent][ragbits.agents.Agent] can call other agents as tools, creating a chain of reasoning where downstream agents provide structured results to the parent agent. + +Using the streaming API, you can observe every chunk of output as it is generated, including tool calls, tool results, and final text - perfect for real-time monitoring or chat interfaces. + +## Define a simple tool + +A tool is just a Python function returning a JSON-serializable result. Here’s an example tool returning the current time for a given location: + +```python +import json + +--8<-- "examples/agents/downstream_agents_streaming.py:33:51" +``` + +## Create a downstream agent + +The downstream agent wraps the tool with a prompt, allowing the LLM to use it as a function. + +```python +from pydantic import BaseModel +from ragbits.core.prompt import Prompt +from ragbits.agents import Agent +from ragbits.agents._main import AgentOptions +from ragbits.core.llms import LiteLLM + +--8<-- "examples/agents/downstream_agents_streaming.py:54:82" +``` + +## Create a parent QA agent + +The parent agent can call downstream agents as tools. This lets the LLM reason and decide when to invoke the downstream agent. + +```python +--8<-- "examples/agents/downstream_agents_streaming.py:85:111" +``` + +## Streaming output from downstream agents + +Use `run_streaming` with an [AgentRunContext][ragbits.agents.AgentRunContext] to see output as it happens. Each chunk contains either text, a tool call, or a tool result. You can print agent names when they change and handle downstream agent events. + +```python +import asyncio +from ragbits.agents import DownstreamAgentResult + +--8<-- "examples/agents/downstream_agents_streaming.py:114:133" +``` diff --git a/examples/README.md b/examples/README.md index a5291f3a0..e8fdbcb7b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -39,7 +39,8 @@ All necessary details are provided in the comments at the top of each script. | [Recontextualize Last Message](/examples/chat/recontextualize_message.py) | [ragbits-chat](/packages/ragbits-chat) | Example of how to use the `StandaloneMessageCompressor` compressor to recontextualize the last message in a conversation history. | | [Agents Tool Use](/examples/agents/tool_use.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use agent with tools. | | [Agents OpenAI Native Tool Use](/examples/agents/openai_native_tool_use.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use agent with OpenAI native tools. | -| [Agents Post Processors](/examples/agents/post_processors.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use post-processors with agent. | +| [Agents Post Processors](/examples/agents/post_processors.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use post-processors with agent. 
+| [Agents Downstream Streaming](/examples/agents/downstream_agents_streaming.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to stream outputs from downstream agents in real time. | | | [Agents CLI](/examples/agents/cli_agent.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use agent in CLI. | | [MCP Local](/examples/agents/mcp/local.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use the `Agent` class to connect with a local MCP server. | | [MCP SSE](/examples/agents/mcp/sse.py) | [ragbits-agents](/packages/ragbits-agents) | Example of how to use the `Agent` class to connect with a remote MCP server via SSE. | diff --git a/examples/agents/downstream_agents_streaming.py b/examples/agents/downstream_agents_streaming.py new file mode 100644 index 000000000..b68fb8dc3 --- /dev/null +++ b/examples/agents/downstream_agents_streaming.py @@ -0,0 +1,133 @@ +""" +Ragbits Agents Example: Multi-agent setup (QA agent + Time agent) + +This example demonstrates how to build a setup with two agents: +1. A Time Agent that returns the current time for a given location. +2. A QA Agent that answers user questions and can delegate to the Time Agent. + +To run the script, execute the following command: + + ```bash + uv run examples/agents/downstream_agents_streaming.py + ``` +""" + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "ragbits-core", +# "ragbits-agents", +# ] +# /// + +import asyncio +import json + +from pydantic import BaseModel + +from ragbits.agents import Agent, AgentOptions, AgentRunContext, DownstreamAgentResult +from ragbits.core.llms import LiteLLM +from ragbits.core.prompt import Prompt + + +def get_time(location: str) -> str: + """ + Returns the current time for a given location. + + Args: + location: The location to get the time for. + + Returns: + The current time for the given location. + """ + loc = location.lower() + if "tokyo" in loc: + return json.dumps({"location": "Tokyo", "time": "10:00 AM"}) + elif "paris" in loc: + return json.dumps({"location": "Paris", "time": "04:00 PM"}) + elif "san francisco" in loc: + return json.dumps({"location": "San Francisco", "time": "07:00 PM"}) + else: + return json.dumps({"location": location, "time": "unknown"}) + + +class TimePromptInput(BaseModel): + """Input schema for the TimePrompt, containing the target location.""" + + location: str + + +class TimePrompt(Prompt[TimePromptInput]): + """ + Provides instructions for generating the current time in a user-specified + location. + """ + + system_prompt = """ + You are a helpful assistant that tells the current time in a given city. + """ + user_prompt = """ + What time is it in {{ location }}? + """ + + +llm = LiteLLM(model_name="gpt-4o-2024-08-06", use_structured_output=True) +time_agent = Agent( + name="time_agent", + description="Returns current time for a given location", + llm=llm, + prompt=TimePrompt, + tools=[get_time], + default_options=AgentOptions(max_total_tokens=1000, max_turns=5), +) + + +class QAPromptInput(BaseModel): + """Input schema for the QA agent, containing a natural-language question.""" + + question: str + + +class QAPrompt(Prompt[QAPromptInput]): + """ + Guides the agent to respond to user questions. + """ + + system_prompt = """ + You are a helpful assistant that responds to user questions. + """ + user_prompt = """ + {{ question }}. 
+ """ + + +llm = LiteLLM(model_name="gpt-4o-2024-08-06", use_structured_output=True) +qa_agent = Agent( + name="qa_agent", + llm=llm, + prompt=QAPrompt, + tools=[(time_agent, {"name": "time_agent"})], + default_options=AgentOptions(max_total_tokens=1000, max_turns=5), +) + + +async def main() -> None: + """ + Run the QA agent with downstream streaming enabled. + + The QA agent processes a sample question ("What time is it in Paris?") and delegates to + the Time Agent when necessary. Streamed results from both agents are printed in real time, + tagged by the agent that produced them. + """ + context = AgentRunContext(stream_downstream_events=True) + + async for chunk in qa_agent.run_streaming(QAPromptInput(question="What time is it in Paris?"), context=context): + if isinstance(chunk, DownstreamAgentResult): + agent_name = context.get_agent(chunk.agent_id).name + print(f"[{agent_name}] {chunk.item}") + else: + print(f"[{qa_agent.name}] {chunk}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 8c534094d..a62180d7c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -41,6 +41,7 @@ nav: - "Provide tools with MCP": how-to/agents/provide_mcp_tools.md - "Serve agents with A2A": how-to/agents/serve_ragbits_agents.md - "Use post-processors": how-to/agents/use_post_processors.md + - "Stream downstream agents": how-to/agents/stream_downstream_agents.md - Guardrails: - "Setup guardrails": how-to/guardrails/use_guardrails.md - Chatbots: diff --git a/packages/ragbits-agents/CHANGELOG.md b/packages/ragbits-agents/CHANGELOG.md index c0eef472b..3d0e2b2a3 100644 --- a/packages/ragbits-agents/CHANGELOG.md +++ b/packages/ragbits-agents/CHANGELOG.md @@ -5,6 +5,7 @@ - Support wrapping downstream agents as tools (#818) - Add syntax sugar allowing easier Agents definition (#820) - Add post-processors (#821) +- Support streaming from downstream agents (#812) ## 1.3.0 (2025-09-11) ### Changed diff --git a/packages/ragbits-agents/src/ragbits/agents/__init__.py b/packages/ragbits-agents/src/ragbits/agents/__init__.py index bd1809337..edd8226cf 100644 --- a/packages/ragbits-agents/src/ragbits/agents/__init__.py +++ b/packages/ragbits-agents/src/ragbits/agents/__init__.py @@ -5,6 +5,8 @@ AgentResult, AgentResultStreaming, AgentRunContext, + DownstreamAgentResult, + ToolCall, ToolCallResult, ) from ragbits.agents.post_processors.base import PostProcessor, StreamingPostProcessor @@ -17,10 +19,12 @@ "AgentResult", "AgentResultStreaming", "AgentRunContext", + "DownstreamAgentResult", "PostProcessor", "QuestionAnswerAgent", "QuestionAnswerPromptInput", "QuestionAnswerPromptOutput", "StreamingPostProcessor", + "ToolCall", "ToolCallResult", ] diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index b98b06245..e7c347100 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -2,13 +2,14 @@ import types import uuid from collections.abc import AsyncGenerator, AsyncIterator, Callable +from collections.abc import AsyncGenerator as _AG from contextlib import suppress from copy import deepcopy from dataclasses import dataclass from datetime import timedelta from inspect import iscoroutinefunction from types import ModuleType, SimpleNamespace -from typing import Any, ClassVar, Generic, Literal, TypeVar, cast, overload +from typing import Any, ClassVar, Generic, Literal, TypeVar, Union, cast, overload from pydantic import ( BaseModel, @@ -57,6 +58,26 @@ 
_Output = TypeVar("_Output") +@dataclass +class DownstreamAgentResult: + """ + Represents a streamed item from a downstream agent while executing a tool. + """ + + agent_id: str | None + """ID of the downstream agent.""" + item: Union[ + str, + ToolCall, + ToolCallResult, + "DownstreamAgentResult", + BasePrompt, + Usage, + SimpleNamespace, + ] + """The streamed item from the downstream agent.""" + + @dataclass class AgentResult(Generic[PromptOutputT]): """ @@ -158,13 +179,42 @@ def __contains__(self, key: str) -> bool: class AgentRunContext(BaseModel, Generic[DepsT]): """Context for the agent run.""" + model_config = {"arbitrary_types_allowed": True} + deps: AgentDependencies[DepsT] = Field(default_factory=lambda: AgentDependencies()) """Container for external dependencies.""" usage: Usage = Field(default_factory=Usage) """The usage of the agent.""" + stream_downstream_events: bool = False + """Whether to stream events from downstream agents when tools execute other agents.""" + downstream_agents: dict[str, "Agent"] = Field(default_factory=dict) + """Registry of all agents that participated in this run""" + + def register_agent(self, agent: "Agent") -> None: + """ + Register a downstream agent in this context. + + Args: + agent: The agent instance to register. + """ + self.downstream_agents[agent.id] = agent + + def get_agent(self, agent_id: str) -> "Agent | None": + """ + Retrieve a registered downstream agent by its ID. + + Args: + agent_id: The unique identifier of the agent. + + Returns: + The Agent instance if found, otherwise None. + """ + return self.downstream_agents.get(agent_id) -class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace]): +class AgentResultStreaming( + AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult] +): """ An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. 
After the loop completes, @@ -172,49 +222,61 @@ class AgentResultStreaming(AsyncIterator[str | ToolCall | ToolCallResult | BaseP """ def __init__( - self, generator: AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage] + self, + generator: AsyncGenerator[ + str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage + ], ): self._generator = generator self.content: str = "" self.tool_calls: list[ToolCallResult] | None = None + self.downstream: dict[str | None, list[str | ToolCall | ToolCallResult]] = {} self.metadata: dict = {} self.history: ChatFormat self.usage: Usage = Usage() - def __aiter__(self) -> AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace]: + def __aiter__( + self, + ) -> AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult]: return self - async def __anext__(self) -> str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace: + async def __anext__( + self, + ) -> str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult: try: item = await self._generator.__anext__() match item: case str(): self.content += item - return item case ToolCall(): - return item + pass case ToolCallResult(): if self.tool_calls is None: self.tool_calls = [] self.tool_calls.append(item) - return item + case DownstreamAgentResult(): + if item.agent_id not in self.downstream: + self.downstream[item.agent_id] = [] + if isinstance(item.item, str | ToolCall | ToolCallResult): + self.downstream[item.agent_id].append(item.item) case BasePrompt(): item.add_assistant_message(self.content) self.history = item.chat - return item case Usage(): self.usage = item + # continue loop instead of tail recursion return await self.__anext__() case SimpleNamespace(): result_dict = getattr(item, "result", {}) self.content = result_dict.get("content", self.content) self.metadata = result_dict.get("metadata", self.metadata) self.tool_calls = result_dict.get("tool_calls", self.tool_calls) - return item case _: raise ValueError(f"Unexpected item: {item}") + return item + except StopAsyncIteration: raise @@ -399,10 +461,12 @@ async def _run_without_post_processing( break for tool_call in response.tool_calls: - result = await self._execute_tool(tool_call=tool_call, tools_mapping=tools_mapping, context=context) - tool_calls.append(result) - - prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) + async for result in self._execute_tool( + tool_call=tool_call, tools_mapping=tools_mapping, context=context + ): + if isinstance(result, ToolCallResult): + tool_calls.append(result) + prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) turn_count += 1 else: @@ -516,13 +580,23 @@ def run_streaming( AgentMaxTurnsExceededError: If the maximum number of turns is exceeded. AgentInvalidPostProcessorError: If the post-processor is invalid. 
""" - generator = self._stream_internal(input, options, context, tool_choice) + generator = self._stream_internal( + input=input, + options=options, + context=context, + tool_choice=tool_choice, + ) if post_processors: if not allow_non_streaming and any(not p.supports_streaming for p in post_processors): raise AgentInvalidPostProcessorError( reason="Non-streaming post-processors are not allowed when allow_non_streaming is False" ) + + generator = cast( + _AG[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage], + generator, + ) generator = stream_with_post_processing(generator, post_processors, self) return AgentResultStreaming(generator) @@ -533,10 +607,12 @@ async def _stream_internal( options: AgentOptions[LLMClientOptionsT] | None = None, context: AgentRunContext | None = None, tool_choice: ToolChoice | None = None, - ) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]: + ) -> AsyncGenerator[str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage]: if context is None: context = AgentRunContext() + context.register_agent(cast(Agent[Any, Any, str], self)) + input = cast(PromptInputT, input) merged_options = (self.default_options | options) if options else self.default_options llm_options = merged_options.llm_options or self.llm.default_options @@ -546,24 +622,33 @@ async def _stream_internal( turn_count = 0 max_turns = merged_options.max_turns max_turns = 10 if max_turns is NOT_GIVEN else max_turns + with trace(input=input, options=merged_options) as outputs: while not max_turns or turn_count < max_turns: returned_tool_call = False self._check_token_limits(merged_options, context.usage, prompt_with_history, self.llm) + streaming_result = self.llm.generate_streaming( prompt=prompt_with_history, tools=[tool.to_function_schema() for tool in tools_mapping.values()], tool_choice=tool_choice if tool_choice and turn_count == 0 else None, options=self._get_llm_options(llm_options, merged_options, context.usage), ) + async for chunk in streaming_result: yield chunk if isinstance(chunk, ToolCall): - result = await self._execute_tool(tool_call=chunk, tools_mapping=tools_mapping, context=context) - yield result - prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) - returned_tool_call = True + async for result in self._execute_tool( + tool_call=chunk, + tools_mapping=tools_mapping, + context=context, + ): + yield result + if isinstance(result, ToolCallResult): + prompt_with_history = prompt_with_history.add_tool_use_message(**result.__dict__) + returned_tool_call = True + turn_count += 1 if streaming_result.usage: context.usage += streaming_result.usage @@ -693,11 +778,10 @@ async def _execute_tool( self, tool_call: ToolCall, tools_mapping: dict[str, Tool], - context: AgentRunContext | None = None, - ) -> ToolCallResult: + context: AgentRunContext, + ) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult, None]: if tool_call.type != "function": raise AgentToolNotSupportedError(tool_call.type) - if tool_call.name not in tools_mapping: raise AgentToolNotAvailableError(tool_call.name) @@ -716,8 +800,9 @@ async def _execute_tool( ) if isinstance(tool_output, AgentResultStreaming): - async for _ in tool_output: - pass + async for downstream_item in tool_output: + if context.stream_downstream_events: + yield DownstreamAgentResult(agent_id=tool.id, item=downstream_item) tool_output = { "content": tool_output.content, @@ -738,7 +823,7 @@ async def _execute_tool( } raise 
AgentToolExecutionError(tool_call.name, e) from e - return ToolCallResult( + yield ToolCallResult( id=tool_call.id, name=tool_call.name, arguments=tool_call.arguments, diff --git a/packages/ragbits-agents/src/ragbits/agents/tool.py b/packages/ragbits-agents/src/ragbits/agents/tool.py index 3da403782..c8e279f29 100644 --- a/packages/ragbits-agents/src/ragbits/agents/tool.py +++ b/packages/ragbits-agents/src/ragbits/agents/tool.py @@ -1,7 +1,7 @@ from collections.abc import Callable from contextlib import suppress from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from pydantic import BaseModel from typing_extensions import Self @@ -12,7 +12,7 @@ from ragbits.core.utils.function_schema import convert_function_to_function_schema, get_context_variable_name if TYPE_CHECKING: - from ragbits.agents import Agent, AgentResultStreaming + from ragbits.agents import Agent, AgentResultStreaming, AgentRunContext with suppress(ImportError): from pydantic_ai import Tool as PydanticAITool @@ -50,6 +50,7 @@ class Tool: """The actual callable function to execute when the tool is called.""" context_var_name: str | None = None """The name of the context variable that this tool accepts.""" + id: str | None = None @classmethod def from_callable(cls, callable: Callable) -> Self: @@ -150,20 +151,28 @@ def from_agent( parameters = {"type": "object", "properties": properties, "required": required} + context_var_name = get_context_variable_name(agent.run) + def _on_tool_call(**kwargs: dict) -> "AgentResultStreaming": + if context_var_name: + context = cast("AgentRunContext[Any] | None", kwargs.get(context_var_name)) + if context is not None: + context.register_agent(cast("Agent[Any, Any, str]", agent)) + if input_model_cls and issubclass(input_model_cls, BaseModel): model_input = input_model_cls(**kwargs) else: model_input = kwargs.get("input") - return agent.run_streaming(model_input) + return agent.run_streaming(model_input, context=context) return cls( name=variable_name, + id=agent.id, description=description, parameters=parameters, on_tool_call=_on_tool_call, - context_var_name=get_context_variable_name(agent.run), + context_var_name=context_var_name, )
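
For readers of this patch, below is a minimal consumer-side sketch of the streaming behaviour it introduces. It is not part of the diff: it assumes the `qa_agent` and `QAPromptInput` definitions from `examples/agents/downstream_agents_streaming.py` are already in scope, and it guards the `get_agent` lookup (which returns `None` for unregistered IDs) instead of dereferencing it directly.

```python
# Sketch only: assumes `qa_agent` and `QAPromptInput` from
# examples/agents/downstream_agents_streaming.py are already defined in scope.
import asyncio

from ragbits.agents import AgentRunContext, DownstreamAgentResult, ToolCall, ToolCallResult


async def consume() -> None:
    # Downstream events are emitted only when this flag is set on the run context.
    context = AgentRunContext(stream_downstream_events=True)
    stream = qa_agent.run_streaming(
        QAPromptInput(question="What time is it in Tokyo?"), context=context
    )

    async for chunk in stream:
        if isinstance(chunk, DownstreamAgentResult):
            # get_agent returns None for IDs that were never registered,
            # so guard the lookup before reading `.name`.
            agent = context.get_agent(chunk.agent_id) if chunk.agent_id else None
            label = agent.name if agent else "unknown-agent"
            print(f"[{label}] {chunk.item}")
        elif isinstance(chunk, (ToolCall, ToolCallResult)):
            # Tool events produced by the parent agent itself.
            print(f"[{qa_agent.name}] tool event: {chunk}")
        elif isinstance(chunk, str):
            print(chunk, end="", flush=True)

    # After iteration, AgentResultStreaming also keeps a per-agent record of the
    # downstream items it saw, keyed by agent ID.
    print(stream.downstream)


if __name__ == "__main__":
    asyncio.run(consume())
```

Iterating the stream once both surfaces events in real time and fills `downstream`, so the same pass can drive a live interface and keep a per-agent transcript.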