Skip to content

Commit a3ad970

Browse files
committed
ag-ui on_complete callback #2398
1 parent 544ff88 commit a3ad970

File tree

3 files changed

+139
-3
lines changed

3 files changed

+139
-3
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import json
1010
import uuid
11-
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
11+
from collections.abc import AsyncIterator, Awaitable, Iterable, Mapping, Sequence
1212
from dataclasses import Field, dataclass, replace
1313
from http import HTTPStatus
1414
from typing import (
@@ -19,11 +19,14 @@
1919
Generic,
2020
Protocol,
2121
TypeVar,
22+
Union,
2223
runtime_checkable,
2324
)
2425

2526
from pydantic import BaseModel, ValidationError
2627

28+
from pydantic_ai import _utils
29+
2730
from ._agent_graph import CallToolsNode, ModelRequestNode
2831
from .agent import Agent, AgentRun
2932
from .exceptions import UserError
@@ -104,13 +107,17 @@
104107
'StateDeps',
105108
'StateHandler',
106109
'AGUIApp',
110+
'AgentRunCallback',
107111
'handle_ag_ui_request',
108112
'run_ag_ui',
109113
]
110114

111115
SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
112116
"""Content type header value for Server-Sent Events (SSE)."""
113117

118+
AgentRunCallback = Callable[[AgentRun[Any, Any]], Union[None, Awaitable[None]]]
119+
"""Callback function type that receives the completed AgentRun. Can be sync or async."""
120+
114121

115122
class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
116123
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
@@ -158,7 +165,6 @@ def __init__(
158165
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
159166
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
160167
toolsets: Optional additional toolsets for this run.
161-
162168
debug: Boolean indicating if debug tracebacks should be returned on errors.
163169
routes: A list of routes to serve incoming HTTP and WebSocket requests.
164170
middleware: A list of middleware to run for every request. A starlette application will always
@@ -217,6 +223,7 @@ async def handle_ag_ui_request(
217223
usage: Usage | None = None,
218224
infer_name: bool = True,
219225
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
226+
on_complete: AgentRunCallback | None = None,
220227
) -> Response:
221228
"""Handle an AG-UI request by running the agent and returning a streaming response.
222229
@@ -233,6 +240,8 @@ async def handle_ag_ui_request(
233240
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
234241
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
235242
toolsets: Optional additional toolsets for this run.
243+
on_complete: Optional callback function called when the agent run completes successfully.
244+
The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.
236245
237246
Returns:
238247
A streaming Starlette response with AG-UI protocol events.
@@ -260,6 +269,7 @@ async def handle_ag_ui_request(
260269
usage=usage,
261270
infer_name=infer_name,
262271
toolsets=toolsets,
272+
on_complete=on_complete,
263273
),
264274
media_type=accept,
265275
)
@@ -278,6 +288,7 @@ async def run_ag_ui(
278288
usage: Usage | None = None,
279289
infer_name: bool = True,
280290
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
291+
on_complete: AgentRunCallback | None = None,
281292
) -> AsyncIterator[str]:
282293
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
283294
@@ -295,6 +306,8 @@ async def run_ag_ui(
295306
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
296307
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
297308
toolsets: Optional additional toolsets for this run.
309+
on_complete: Optional callback function called when the agent run completes successfully.
310+
The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.
298311
299312
Yields:
300313
Streaming event chunks encoded as strings according to the accept header value.
@@ -362,6 +375,11 @@ async def run_ag_ui(
362375
) as run:
363376
async for event in _agent_stream(run):
364377
yield encoder.encode(event)
378+
if on_complete is not None:
379+
if _utils.is_async_callable(on_complete):
380+
await _utils.run_in_executor(on_complete, run)
381+
else:
382+
on_complete(run)
365383
except _RunError as e:
366384
yield encoder.encode(
367385
RunErrorEvent(message=e.message, code=e.code),

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1905,7 +1905,6 @@ def to_ag_ui(
19051905
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
19061906
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
19071907
toolsets: Optional additional toolsets for this run.
1908-
19091908
debug: Boolean indicating if debug tracebacks should be returned on errors.
19101909
routes: A list of routes to serve incoming HTTP and WebSocket requests.
19111910
middleware: A list of middleware to run for every request. A starlette application will always

tests/test_ag_ui.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,3 +1253,122 @@ async def test_to_ag_ui() -> None:
12531253
events.append(json.loads(line.removeprefix('data: ')))
12541254

12551255
assert events == simple_result()
1256+
1257+
1258+
async def test_callback_sync() -> None:
1259+
"""Test that sync callbacks work correctly."""
1260+
from pydantic_ai.agent import AgentRun
1261+
1262+
captured_runs: list[AgentRun[Any, Any]] = []
1263+
1264+
def sync_callback(agent_run: AgentRun[Any, Any]) -> None:
1265+
captured_runs.append(agent_run)
1266+
1267+
agent = Agent(TestModel())
1268+
run_input = create_input(
1269+
UserMessage(
1270+
id='msg1',
1271+
content='Hello!',
1272+
)
1273+
)
1274+
1275+
events: list[dict[str, Any]] = []
1276+
async for event in run_ag_ui(agent, run_input, on_complete=sync_callback):
1277+
events.append(json.loads(event.removeprefix('data: ')))
1278+
1279+
# Verify callback was called
1280+
assert len(captured_runs) == 1
1281+
agent_run = captured_runs[0]
1282+
1283+
# Verify we can access messages
1284+
assert agent_run.result is not None, 'AgentRun result should be available in callback'
1285+
messages = agent_run.result.all_messages()
1286+
assert len(messages) >= 1
1287+
1288+
# Verify events were still streamed normally
1289+
assert len(events) > 0
1290+
assert events[0]['type'] == 'RUN_STARTED'
1291+
assert events[-1]['type'] == 'RUN_FINISHED'
1292+
1293+
1294+
async def test_callback_async() -> None:
1295+
"""Test that async callbacks work correctly."""
1296+
from pydantic_ai.agent import AgentRun
1297+
1298+
captured_runs: list[AgentRun[Any, Any]] = []
1299+
1300+
async def async_callback(agent_run: AgentRun[Any, Any]) -> None:
1301+
captured_runs.append(agent_run)
1302+
1303+
agent = Agent(TestModel())
1304+
run_input = create_input(
1305+
UserMessage(
1306+
id='msg1',
1307+
content='Hello!',
1308+
)
1309+
)
1310+
1311+
events: list[dict[str, Any]] = []
1312+
async for event in run_ag_ui(agent, run_input, on_complete=async_callback):
1313+
events.append(json.loads(event.removeprefix('data: ')))
1314+
1315+
# Verify callback was called
1316+
assert len(captured_runs) == 1
1317+
agent_run = captured_runs[0]
1318+
1319+
# Verify we can access messages
1320+
assert agent_run.result is not None, 'AgentRun result should be available in callback'
1321+
messages = agent_run.result.all_messages()
1322+
assert len(messages) >= 1
1323+
1324+
# Verify events were still streamed normally
1325+
assert len(events) > 0
1326+
assert events[0]['type'] == 'RUN_STARTED'
1327+
assert events[-1]['type'] == 'RUN_FINISHED'
1328+
1329+
1330+
async def test_callback_none() -> None:
1331+
"""Test that passing None for callback works (backwards compatibility)."""
1332+
1333+
agent = Agent(TestModel())
1334+
run_input = create_input(
1335+
UserMessage(
1336+
id='msg1',
1337+
content='Hello!',
1338+
)
1339+
)
1340+
1341+
events: list[dict[str, Any]] = []
1342+
async for event in run_ag_ui(agent, run_input, on_complete=None):
1343+
events.append(json.loads(event.removeprefix('data: ')))
1344+
1345+
# Verify events were still streamed normally
1346+
assert len(events) > 0
1347+
assert events[0]['type'] == 'RUN_STARTED'
1348+
assert events[-1]['type'] == 'RUN_FINISHED'
1349+
1350+
1351+
async def test_callback_with_error() -> None:
1352+
"""Test that callbacks are not called when errors occur."""
1353+
from pydantic_ai.agent import AgentRun
1354+
1355+
captured_runs: list[AgentRun[Any, Any]] = []
1356+
1357+
def error_callback(agent_run: AgentRun[Any, Any]) -> None:
1358+
captured_runs.append(agent_run)
1359+
1360+
agent = Agent(TestModel())
1361+
# Empty messages should cause an error
1362+
run_input = create_input() # No messages will cause _NoMessagesError
1363+
1364+
events: list[dict[str, Any]] = []
1365+
async for event in run_ag_ui(agent, run_input, on_complete=error_callback):
1366+
events.append(json.loads(event.removeprefix('data: ')))
1367+
1368+
# Verify callback was not called due to error
1369+
assert len(captured_runs) == 0
1370+
1371+
# Verify error event was sent
1372+
assert len(events) > 0
1373+
assert events[0]['type'] == 'RUN_STARTED'
1374+
assert any(event['type'] == 'RUN_ERROR' for event in events)

0 commit comments

Comments
 (0)