Skip to content

Add on_complete callback to AG-UI functions to get access to AgentRunResult #2429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import uuid
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
from dataclasses import Field, dataclass, replace
from http import HTTPStatus
from typing import (
Expand All @@ -19,11 +19,14 @@
Generic,
Protocol,
TypeVar,
Union,
runtime_checkable,
)

from pydantic import BaseModel, ValidationError

from pydantic_ai import _utils

from ._agent_graph import CallToolsNode, ModelRequestNode
from .agent import Agent, AgentRun
from .exceptions import UserError
Expand Down Expand Up @@ -104,13 +107,17 @@
'StateDeps',
'StateHandler',
'AGUIApp',
'AgentRunCallback',
'handle_ag_ui_request',
'run_ag_ui',
]

SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
"""Content type header value for Server-Sent Events (SSE)."""

AgentRunCallback = Callable[[AgentRun[Any, Any]], Union[None, Awaitable[None]]]
"""Callback function type that receives the completed AgentRun. Can be sync or async."""


class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
Expand Down Expand Up @@ -158,7 +165,6 @@ def __init__(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.

debug: Boolean indicating if debug tracebacks should be returned on errors.
routes: A list of routes to serve incoming HTTP and WebSocket requests.
middleware: A list of middleware to run for every request. A starlette application will always
Expand Down Expand Up @@ -217,6 +223,7 @@ async def handle_ag_ui_request(
usage: Usage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
on_complete: AgentRunCallback | None = None,
) -> Response:
"""Handle an AG-UI request by running the agent and returning a streaming response.

Expand All @@ -233,6 +240,8 @@ async def handle_ag_ui_request(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
on_complete: Optional callback function called when the agent run completes successfully.
The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.

Returns:
A streaming Starlette response with AG-UI protocol events.
Expand Down Expand Up @@ -260,6 +269,7 @@ async def handle_ag_ui_request(
usage=usage,
infer_name=infer_name,
toolsets=toolsets,
on_complete=on_complete,
),
media_type=accept,
)
Expand All @@ -278,6 +288,7 @@ async def run_ag_ui(
usage: Usage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
on_complete: AgentRunCallback | None = None,
) -> AsyncIterator[str]:
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.

Expand All @@ -295,6 +306,8 @@ async def run_ag_ui(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
on_complete: Optional callback function called when the agent run completes successfully.
The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.

Yields:
Streaming event chunks encoded as strings according to the accept header value.
Expand Down Expand Up @@ -362,6 +375,11 @@ async def run_ag_ui(
) as run:
async for event in _agent_stream(run):
yield encoder.encode(event)
if on_complete is not None:
if _utils.is_async_callable(on_complete):
await on_complete(run)
else:
await _utils.run_in_executor(on_complete, run)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DouweM not sure this is what you meant...

if _utils.is_async_callable(on_complete):
    await on_complete(run)

"None" is not awaitable
"None" is incompatible with protocol "Awaitable[_T_co@Awaitable]"
"await" is not present

And

else:
     await _utils.run_in_executor(on_complete, run)

Is saying it's unreachable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I just ran in my project with both a sync and async callback and it seems to work as expected... Maybe just a type issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird, I'm seeing the same, looks like is_async_callable's TypeIs return annotation is subtly broken... I'm Ok with going back to what you had in that case, running this in an executor (i.e. a thread) seems less important than doing so for tools anyway.

except _RunError as e:
yield encoder.encode(
RunErrorEvent(message=e.message, code=e.code),
Expand Down
1 change: 0 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,6 @@ def to_ag_ui(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.

debug: Boolean indicating if debug tracebacks should be returned on errors.
routes: A list of routes to serve incoming HTTP and WebSocket requests.
middleware: A list of middleware to run for every request. A starlette application will always
Expand Down
119 changes: 119 additions & 0 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,3 +1253,122 @@ async def test_to_ag_ui() -> None:
events.append(json.loads(line.removeprefix('data: ')))

assert events == simple_result()


async def test_callback_sync() -> None:
"""Test that sync callbacks work correctly."""
from pydantic_ai.agent import AgentRun

captured_runs: list[AgentRun[Any, Any]] = []

def sync_callback(agent_run: AgentRun[Any, Any]) -> None:
captured_runs.append(agent_run)

agent = Agent(TestModel())
run_input = create_input(
UserMessage(
id='msg1',
content='Hello!',
)
)

events: list[dict[str, Any]] = []
async for event in run_ag_ui(agent, run_input, on_complete=sync_callback):
events.append(json.loads(event.removeprefix('data: ')))

# Verify callback was called
assert len(captured_runs) == 1
agent_run = captured_runs[0]

# Verify we can access messages
assert agent_run.result is not None, 'AgentRun result should be available in callback'
messages = agent_run.result.all_messages()
assert len(messages) >= 1

# Verify events were still streamed normally
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert events[-1]['type'] == 'RUN_FINISHED'


async def test_callback_async() -> None:
"""Test that async callbacks work correctly."""
from pydantic_ai.agent import AgentRun

captured_runs: list[AgentRun[Any, Any]] = []

async def async_callback(agent_run: AgentRun[Any, Any]) -> None:
captured_runs.append(agent_run)

agent = Agent(TestModel())
run_input = create_input(
UserMessage(
id='msg1',
content='Hello!',
)
)

events: list[dict[str, Any]] = []
async for event in run_ag_ui(agent, run_input, on_complete=async_callback):
events.append(json.loads(event.removeprefix('data: ')))

# Verify callback was called
assert len(captured_runs) == 1
agent_run = captured_runs[0]

# Verify we can access messages
assert agent_run.result is not None, 'AgentRun result should be available in callback'
messages = agent_run.result.all_messages()
assert len(messages) >= 1

# Verify events were still streamed normally
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert events[-1]['type'] == 'RUN_FINISHED'


async def test_callback_none() -> None:
"""Test that passing None for callback works (backwards compatibility)."""

agent = Agent(TestModel())
run_input = create_input(
UserMessage(
id='msg1',
content='Hello!',
)
)

events: list[dict[str, Any]] = []
async for event in run_ag_ui(agent, run_input, on_complete=None):
events.append(json.loads(event.removeprefix('data: ')))

# Verify events were still streamed normally
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert events[-1]['type'] == 'RUN_FINISHED'


async def test_callback_with_error() -> None:
"""Test that callbacks are not called when errors occur."""
from pydantic_ai.agent import AgentRun

captured_runs: list[AgentRun[Any, Any]] = []

def error_callback(agent_run: AgentRun[Any, Any]) -> None:
captured_runs.append(agent_run)

agent = Agent(TestModel())
# Empty messages should cause an error
run_input = create_input() # No messages will cause _NoMessagesError

events: list[dict[str, Any]] = []
async for event in run_ag_ui(agent, run_input, on_complete=error_callback):
events.append(json.loads(event.removeprefix('data: ')))

# Verify callback was not called due to error
assert len(captured_runs) == 0

# Verify error event was sent
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert any(event['type'] == 'RUN_ERROR' for event in events)
Loading