8
8
9
9
import json
10
10
import uuid
11
- from collections .abc import AsyncIterator , Iterable , Mapping , Sequence
11
+ from collections .abc import AsyncIterator , Awaitable , Iterable , Mapping , Sequence
12
12
from dataclasses import Field , dataclass , replace
13
13
from http import HTTPStatus
14
14
from typing import (
19
19
Generic ,
20
20
Protocol ,
21
21
TypeVar ,
22
+ Union ,
22
23
runtime_checkable ,
23
24
)
24
25
25
26
from pydantic import BaseModel , ValidationError
26
27
28
+ from pydantic_ai import _utils
29
+
27
30
from ._agent_graph import CallToolsNode , ModelRequestNode
28
31
from .agent import Agent , AgentRun
29
32
from .exceptions import UserError
104
107
'StateDeps' ,
105
108
'StateHandler' ,
106
109
'AGUIApp' ,
110
+ 'AgentRunCallback' ,
107
111
'handle_ag_ui_request' ,
108
112
'run_ag_ui' ,
109
113
]
110
114
111
115
SSE_CONTENT_TYPE : Final [str ] = 'text/event-stream'
112
116
"""Content type header value for Server-Sent Events (SSE)."""
113
117
118
+ AgentRunCallback = Callable [[AgentRun [Any , Any ]], Union [None , Awaitable [None ]]]
119
+ """Callback function type that receives the completed AgentRun. Can be sync or async."""
120
+
114
121
115
122
class AGUIApp (Generic [AgentDepsT , OutputDataT ], Starlette ):
116
123
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
@@ -158,7 +165,6 @@ def __init__(
158
165
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
159
166
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
160
167
toolsets: Optional additional toolsets for this run.
161
-
162
168
debug: Boolean indicating if debug tracebacks should be returned on errors.
163
169
routes: A list of routes to serve incoming HTTP and WebSocket requests.
164
170
middleware: A list of middleware to run for every request. A starlette application will always
@@ -217,6 +223,7 @@ async def handle_ag_ui_request(
217
223
usage : Usage | None = None ,
218
224
infer_name : bool = True ,
219
225
toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
226
+ on_complete : AgentRunCallback | None = None ,
220
227
) -> Response :
221
228
"""Handle an AG-UI request by running the agent and returning a streaming response.
222
229
@@ -233,6 +240,8 @@ async def handle_ag_ui_request(
233
240
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
234
241
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
235
242
toolsets: Optional additional toolsets for this run.
243
+ on_complete: Optional callback function called when the agent run completes successfully.
244
+ The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.
236
245
237
246
Returns:
238
247
A streaming Starlette response with AG-UI protocol events.
@@ -260,6 +269,7 @@ async def handle_ag_ui_request(
260
269
usage = usage ,
261
270
infer_name = infer_name ,
262
271
toolsets = toolsets ,
272
+ on_complete = on_complete ,
263
273
),
264
274
media_type = accept ,
265
275
)
@@ -278,6 +288,7 @@ async def run_ag_ui(
278
288
usage : Usage | None = None ,
279
289
infer_name : bool = True ,
280
290
toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
291
+ on_complete : AgentRunCallback | None = None ,
281
292
) -> AsyncIterator [str ]:
282
293
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
283
294
@@ -295,6 +306,8 @@ async def run_ag_ui(
295
306
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
296
307
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
297
308
toolsets: Optional additional toolsets for this run.
309
+ on_complete: Optional callback function called when the agent run completes successfully.
310
+ The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.
298
311
299
312
Yields:
300
313
Streaming event chunks encoded as strings according to the accept header value.
@@ -316,6 +329,7 @@ async def run_ag_ui(
316
329
)
317
330
toolsets = [* toolsets , toolset ] if toolsets else [toolset ]
318
331
332
+ completed_run : AgentRun [AgentDepsT , Any ] | None = None
319
333
try :
320
334
yield encoder .encode (
321
335
RunStartedEvent (
@@ -362,6 +376,11 @@ async def run_ag_ui(
362
376
) as run :
363
377
async for event in _agent_stream (run ):
364
378
yield encoder .encode (event )
379
+ if on_complete is not None :
380
+ if _utils .is_async_callable (on_complete ):
381
+ await _utils .run_in_executor (on_complete , run )
382
+ else :
383
+ on_complete (completed_run )
365
384
except _RunError as e :
366
385
yield encoder .encode (
367
386
RunErrorEvent (message = e .message , code = e .code ),
0 commit comments