From 73af2b79a3b1f035a1d2dcdd2c759fb30046a586 Mon Sep 17 00:00:00 2001 From: charles jonas Date: Tue, 5 Aug 2025 09:16:44 -0600 Subject: [PATCH] ag-ui on_complete callback #2398 --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 22 ++++- pydantic_ai_slim/pydantic_ai/agent.py | 1 - tests/test_ag_ui.py | 119 ++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 1ea6f16eb..7af24887e 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -8,7 +8,7 @@ import json import uuid -from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence from dataclasses import Field, dataclass, replace from http import HTTPStatus from typing import ( @@ -19,11 +19,14 @@ Generic, Protocol, TypeVar, + Union, runtime_checkable, ) from pydantic import BaseModel, ValidationError +from pydantic_ai import _utils + from ._agent_graph import CallToolsNode, ModelRequestNode from .agent import Agent, AgentRun from .exceptions import UserError @@ -104,6 +107,7 @@ 'StateDeps', 'StateHandler', 'AGUIApp', + 'AgentRunCallback', 'handle_ag_ui_request', 'run_ag_ui', ] @@ -111,6 +115,9 @@ SSE_CONTENT_TYPE: Final[str] = 'text/event-stream' """Content type header value for Server-Sent Events (SSE).""" +AgentRunCallback = Callable[[AgentRun[Any, Any]], Union[None, Awaitable[None]]] +"""Callback function type that receives the completed AgentRun. Can be sync or async.""" + class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): """ASGI application for running Pydantic AI agents with AG-UI protocol support.""" @@ -158,7 +165,6 @@ def __init__( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. - debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. middleware: A list of middleware to run for every request. A starlette application will always @@ -217,6 +223,7 @@ async def handle_ag_ui_request( usage: Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + on_complete: AgentRunCallback | None = None, ) -> Response: """Handle an AG-UI request by running the agent and returning a streaming response. @@ -233,6 +240,8 @@ async def handle_ag_ui_request( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + on_complete: Optional callback function called when the agent run completes successfully. + The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data. Returns: A streaming Starlette response with AG-UI protocol events. @@ -260,6 +269,7 @@ async def handle_ag_ui_request( usage=usage, infer_name=infer_name, toolsets=toolsets, + on_complete=on_complete, ), media_type=accept, ) @@ -278,6 +288,7 @@ async def run_ag_ui( usage: Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + on_complete: AgentRunCallback | None = None, ) -> AsyncIterator[str]: """Run the agent with the AG-UI run input and stream AG-UI protocol events. @@ -295,6 +306,8 @@ async def run_ag_ui( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + on_complete: Optional callback function called when the agent run completes successfully. + The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data. Yields: Streaming event chunks encoded as strings according to the accept header value. @@ -362,6 +375,11 @@ async def run_ag_ui( ) as run: async for event in _agent_stream(run): yield encoder.encode(event) + if on_complete is not None: + if _utils.is_async_callable(on_complete): + await on_complete(run) + else: + await _utils.run_in_executor(on_complete, run) except _RunError as e: yield encoder.encode( RunErrorEvent(message=e.message, code=e.code), diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 271d0ebc7..f24943772 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1905,7 +1905,6 @@ def to_ag_ui( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. - debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. middleware: A list of middleware to run for every request. A starlette application will always diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 8d42fa4c6..2cf9bc493 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1253,3 +1253,122 @@ async def test_to_ag_ui() -> None: events.append(json.loads(line.removeprefix('data: '))) assert events == simple_result() + + +async def test_callback_sync() -> None: + """Test that sync callbacks work correctly.""" + from pydantic_ai.agent import AgentRun + + captured_runs: list[AgentRun[Any, Any]] = [] + + def sync_callback(agent_run: AgentRun[Any, Any]) -> None: + captured_runs.append(agent_run) + + agent = Agent(TestModel()) + run_input = create_input( + UserMessage( + id='msg1', + content='Hello!', + ) + ) + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=sync_callback): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify callback was called + assert len(captured_runs) == 1 + agent_run = captured_runs[0] + + # Verify we can access messages + assert agent_run.result is not None, 'AgentRun result should be available in callback' + messages = agent_run.result.all_messages() + assert len(messages) >= 1 + + # Verify events were still streamed normally + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert events[-1]['type'] == 'RUN_FINISHED' + + +async def test_callback_async() -> None: + """Test that async callbacks work correctly.""" + from pydantic_ai.agent import AgentRun + + captured_runs: list[AgentRun[Any, Any]] = [] + + async def async_callback(agent_run: AgentRun[Any, Any]) -> None: + captured_runs.append(agent_run) + + agent = Agent(TestModel()) + run_input = create_input( + UserMessage( + id='msg1', + content='Hello!', + ) + ) + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=async_callback): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify callback was called + assert len(captured_runs) == 1 + agent_run = captured_runs[0] + + # Verify we can access messages + assert agent_run.result is not None, 'AgentRun result should be available in callback' + messages = agent_run.result.all_messages() + assert len(messages) >= 1 + + # Verify events were still streamed normally + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert events[-1]['type'] == 'RUN_FINISHED' + + +async def test_callback_none() -> None: + """Test that passing None for callback works (backwards compatibility).""" + + agent = Agent(TestModel()) + run_input = create_input( + UserMessage( + id='msg1', + content='Hello!', + ) + ) + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=None): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify events were still streamed normally + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert events[-1]['type'] == 'RUN_FINISHED' + + +async def test_callback_with_error() -> None: + """Test that callbacks are not called when errors occur.""" + from pydantic_ai.agent import AgentRun + + captured_runs: list[AgentRun[Any, Any]] = [] + + def error_callback(agent_run: AgentRun[Any, Any]) -> None: + captured_runs.append(agent_run) + + agent = Agent(TestModel()) + # Empty messages should cause an error + run_input = create_input() # No messages will cause _NoMessagesError + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=error_callback): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify callback was not called due to error + assert len(captured_runs) == 0 + + # Verify error event was sent + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert any(event['type'] == 'RUN_ERROR' for event in events)