Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ...

@overload
Expand All @@ -758,6 +759,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ...

def run_stream_events(
Expand All @@ -776,6 +778,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
"""Run the agent with a user prompt in async mode and stream events from the run.

Expand Down Expand Up @@ -825,6 +828,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools for this run.

Returns:
Expand All @@ -850,6 +854,7 @@ async def main():
usage=usage,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
)

async def _run_stream_events(
Expand All @@ -867,16 +872,27 @@ async def _run_stream_events(
usage: _usage.RunUsage | None = None,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
send_stream, receive_stream = anyio.create_memory_object_stream[
_messages.AgentStreamEvent | AgentRunResultEvent[Any]
]()

async def event_stream_handler(
_: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent]
) -> None:
async def _yield_event_stream(
events: AsyncIterable[_messages.AgentStreamEvent],
) -> AsyncIterator[_messages.AgentStreamEvent]:
async for event in events:
await send_stream.send(event)
yield event

async def _event_stream_handler(
context: RunContext[AgentDepsT], events: AsyncIterable[_messages.AgentStreamEvent]
) -> None:
events = _yield_event_stream(events)
if event_stream_handler is not None:
await event_stream_handler(context, events)
async for _ in events:
pass

async def run_agent() -> AgentRunResult[Any]:
async with send_stream:
Expand All @@ -894,7 +910,7 @@ async def run_agent() -> AgentRunResult[Any]:
infer_name=False,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
event_stream_handler=_event_stream_handler,
)

task = asyncio.create_task(run_agent())
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ...

@overload
Expand All @@ -622,6 +623,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ...

def run_stream_events(
Expand All @@ -640,6 +642,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
"""Run the agent with a user prompt in async mode and stream events from the run.

Expand Down Expand Up @@ -689,6 +692,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools for this run.

Returns:
Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/prefect/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ...

@overload
Expand All @@ -576,6 +577,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ...

def run_stream_events(
Expand All @@ -594,6 +596,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
"""Run the agent with a user prompt in async mode and stream events from the run.

Expand Down Expand Up @@ -643,6 +646,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools for this run.

Returns:
Expand All @@ -669,6 +673,7 @@ async def main():
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
)

@overload
Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[OutputDataT]]: ...

@overload
Expand All @@ -647,6 +648,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[RunOutputDataT]]: ...

def run_stream_events(
Expand All @@ -665,6 +667,7 @@ def run_stream_events(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[_messages.AgentStreamEvent | AgentRunResultEvent[Any]]:
"""Run the agent with a user prompt in async mode and stream events from the run.

Expand Down Expand Up @@ -714,6 +717,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools for this run.

Returns:
Expand All @@ -740,6 +744,7 @@ async def main():
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
)

@overload
Expand Down
11 changes: 10 additions & 1 deletion pydantic_ai_slim/pydantic_ai/ui/_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pydantic_ai import DeferredToolRequests, DeferredToolResults
from pydantic_ai.agent import AbstractAgent
from pydantic_ai.agent.abstract import Instructions
from pydantic_ai.agent.abstract import EventStreamHandler, Instructions
from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models import KnownModelName, Model
Expand Down Expand Up @@ -209,6 +209,7 @@ def run_stream_native(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> AsyncIterator[NativeEvent]:
"""Run the agent with the protocol-specific run input and stream Pydantic AI events.

Expand All @@ -225,6 +226,7 @@ def run_stream_native(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools to use for this run.
"""
message_history = [*(message_history or []), *self.messages]
Expand Down Expand Up @@ -262,6 +264,7 @@ def run_stream_native(
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
)

def run_stream(
Expand All @@ -279,6 +282,7 @@ def run_stream(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
on_complete: OnCompleteFunc[EventT] | None = None,
) -> AsyncIterator[EventT]:
"""Run the agent with the protocol-specific run input and stream protocol-specific events.
Expand All @@ -296,6 +300,7 @@ def run_stream(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools to use for this run.
on_complete: Optional callback function called when the agent run completes successfully.
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can optionally yield additional protocol-specific events.
Expand All @@ -314,6 +319,7 @@ def run_stream(
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
),
on_complete=on_complete,
)
Expand All @@ -336,6 +342,7 @@ async def dispatch_request(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
on_complete: OnCompleteFunc[EventT] | None = None,
) -> Response:
"""Handle a protocol-specific HTTP request by running the agent and returning a streaming response of protocol-specific events.
Expand All @@ -355,6 +362,7 @@ async def dispatch_request(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
builtin_tools: Optional additional builtin tools to use for this run.
on_complete: Optional callback function called when the agent run completes successfully.
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can optionally yield additional protocol-specific events.
Expand Down Expand Up @@ -393,6 +401,7 @@ async def dispatch_request(
infer_name=infer_name,
toolsets=toolsets,
builtin_tools=builtin_tools,
event_stream_handler=event_stream_handler,
on_complete=on_complete,
),
)
34 changes: 34 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic_ai import (
AbstractToolset,
Agent,
AgentRunResultEvent,
AgentStreamEvent,
AudioUrl,
BinaryContent,
Expand Down Expand Up @@ -6225,3 +6226,36 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
]
)
assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')


async def test_run_stream_events_with_event_stream_handler():
async def llm(messages: list[ModelMessage], _info: AgentInfo) -> AsyncIterator[str]:
yield 'ok here is '
yield 'text'

messages: list[list[ModelMessage]] = []
stream_events: list[Any] = []

async def event_stream_handler(context: RunContext[Any], events: AsyncIterable[Any]) -> None:
messages.append(context.messages)
async for event in events:
stream_events.append(event)

agent = Agent(FunctionModel(stream_function=llm))
agent_events = [
event
async for event in agent.run_stream_events(
message_history=[
ModelRequest(parts=[UserPromptPart(content='Hello')]),
],
event_stream_handler=event_stream_handler,
)
]

assert len(stream_events) == len(agent_events) - 1
assert stream_events == agent_events[: len(stream_events)]
result_event = next((event for event in agent_events if isinstance(event, AgentRunResultEvent)), None)
assert result_event is not None
all_messages = result_event.result.all_messages()
assert messages
assert all_messages == messages[-1]