diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index c7c1cb2b5c..eae242b920 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -739,6 +739,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... @overload @@ -758,6 +759,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( @@ -776,6 +778,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. @@ -825,6 +828,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools for this run. Returns: @@ -850,6 +854,7 @@ async def main(): usage=usage, toolsets=toolsets, builtin_tools=builtin_tools, + event_stream_handler=event_stream_handler, ) async def _run_stream_events( @@ -867,16 +872,27 @@ async def _run_stream_events( usage: _usage.RunUsage | None = None, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: send_stream, receive_stream = anyio.create_memory_object_stream[ _messages.AgentStreamEvent | AgentRunResultEvent[Any] ]() - async def event_stream_handler( - _: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent] - ) -> None: + async def _yield_event_stream( + events: AsyncIterable[_messages.AgentStreamEvent], + ) -> AsyncIterator[_messages.AgentStreamEvent]: async for event in events: await send_stream.send(event) + yield event + + async def _event_stream_handler( + context: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent] + ) -> None: + events = _yield_event_stream(events) + if event_stream_handler is not None: + await event_stream_handler(context, events) + async for _ in events: + pass async def run_agent() -> AgentRunResult[Any]: async with send_stream: @@ -894,7 +910,7 @@ async def run_agent() -> AgentRunResult[Any]: infer_name=False, toolsets=toolsets, builtin_tools=builtin_tools, - event_stream_handler=event_stream_handler, + event_stream_handler=_event_stream_handler, ) task = asyncio.create_task(run_agent()) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index 9e1c8ee3c0..4e641f2c55 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -603,6 +603,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... @overload @@ -622,6 +623,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( @@ -640,6 +642,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. @@ -689,6 +692,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools for this run. Returns: diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py index 8b1b6af44a..5fe4e9aaa6 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py @@ -557,6 +557,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... @overload @@ -576,6 +577,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( @@ -594,6 +596,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. @@ -643,6 +646,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools for this run. Returns: @@ -669,6 +673,7 @@ async def main(): infer_name=infer_name, toolsets=toolsets, builtin_tools=builtin_tools, + event_stream_handler=event_stream_handler, ) @overload diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 6e964c8d08..369787458e 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -628,6 +628,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ... @overload @@ -647,6 +648,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ... def run_stream_events( @@ -665,6 +667,7 @@ def run_stream_events( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]: """Run the agent with a user prompt in async mode and stream events from the run. @@ -714,6 +717,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools for this run. Returns: @@ -740,6 +744,7 @@ async def main(): infer_name=infer_name, toolsets=toolsets, builtin_tools=builtin_tools, + event_stream_handler=event_stream_handler, ) @overload diff --git a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py index a1ca12cd6e..bdf44a08de 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/_adapter.py @@ -20,7 +20,7 @@ from pydantic_ai import DeferredToolRequests, DeferredToolResults from pydantic_ai.agent import AbstractAgent -from pydantic_ai.agent.abstract import Instructions +from pydantic_ai.agent.abstract import EventStreamHandler, Instructions from pydantic_ai.builtin_tools import AbstractBuiltinTool from pydantic_ai.messages import ModelMessage from pydantic_ai.models import KnownModelName, Model @@ -209,6 +209,7 @@ def run_stream_native( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> AsyncIterator[NativeEvent]: """Run the agent with the protocol-specific run input and stream Pydantic AI events. @@ -225,6 +226,7 @@ def run_stream_native( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools to use for this run. """ message_history = [*(message_history or []), *self.messages] @@ -262,6 +264,7 @@ def run_stream_native( infer_name=infer_name, toolsets=toolsets, builtin_tools=builtin_tools, + event_stream_handler=event_stream_handler, ) def run_stream( @@ -279,6 +282,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, on_complete: OnCompleteFunc[EventT] | None = None, ) -> AsyncIterator[EventT]: """Run the agent with the protocol-specific run input and stream protocol-specific events. @@ -296,6 +300,7 @@ def run_stream( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools to use for this run. on_complete: Optional callback function called when the agent run completes successfully. The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can optionally yield additional protocol-specific events. @@ -314,6 +319,7 @@ def run_stream( infer_name=infer_name, toolsets=toolsets, builtin_tools=builtin_tools, + event_stream_handler=event_stream_handler, ), on_complete=on_complete, ) @@ -336,6 +342,7 @@ async def dispatch_request( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, builtin_tools: Sequence[AbstractBuiltinTool] | None = None, + event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, on_complete: OnCompleteFunc[EventT] | None = None, ) -> Response: """Handle a protocol-specific HTTP request by running the agent and returning a streaming response of protocol-specific events. @@ -355,6 +362,7 @@ async def dispatch_request( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. builtin_tools: Optional additional builtin tools to use for this run. on_complete: Optional callback function called when the agent run completes successfully. The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can optionally yield additional protocol-specific events. @@ -393,6 +401,7 @@ async def dispatch_request( infer_name=infer_name, toolsets=toolsets, builtin_tools=builtin_tools, + event_stream_handler=event_stream_handler, on_complete=on_complete, ), ) diff --git a/tests/test_agent.py b/tests/test_agent.py index 8cc6b8b38c..3f9afcdfc3 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -19,6 +19,7 @@ from pydantic_ai import ( AbstractToolset, Agent, + AgentRunResultEvent, AgentStreamEvent, AudioUrl, BinaryContent, @@ -6225,3 +6226,36 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ] ) assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",') + + +async def test_run_stream_events_with_event_stream_handler(): + async def llm(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]: + yield 'ok here is ' + yield 'text' + + messages: list[list[ModelMessage]] = [] + stream_events: list[Any] = [] + + async def event_stream_handler(context: RunContext[Any], events: AsyncIterable[Any]) -> None: + messages.append(context.messages) + async for event in events: + stream_events.append(event) + + agent = Agent(FunctionModel(stream_function=llm)) + agent_events = [ + event + async for event in agent.run_stream_events( + message_history=[ + ModelRequest(parts=[UserPromptPart(content='Hello')]), + ], + event_stream_handler=event_stream_handler, + ) + ] + + assert len(stream_events) == len(agent_events) - 1 + assert stream_events == agent_events[: len(stream_events)] + result_event = next((event for event in agent_events if isinstance(event, AgentRunResultEvent)), None) + assert result_event is not None + all_messages = result_event.result.all_messages() + assert messages + assert all_messages == messages[-1]