|
6 | 6 |
|
7 | 7 | from __future__ import annotations |
8 | 8 |
|
9 | | -from collections.abc import AsyncIterator, Callable, Mapping, Sequence |
10 | | -from typing import Any, Generic |
| 9 | +from collections.abc import AsyncIterator, Sequence |
| 10 | +from typing import Any |
11 | 11 |
|
12 | 12 | from . import DeferredToolResults |
13 | 13 | from .agent import AbstractAgent |
14 | 14 | from .messages import ModelMessage |
15 | 15 | from .models import KnownModelName, Model |
16 | | -from .output import OutputDataT, OutputSpec |
| 16 | +from .output import OutputSpec |
17 | 17 | from .settings import ModelSettings |
18 | 18 | from .tools import AgentDepsT |
19 | 19 | from .toolsets import AbstractToolset |
|
27 | 27 | from .ui.ag_ui import ( |
28 | 28 | SSE_CONTENT_TYPE, |
29 | 29 | AGUIAdapter, |
| 30 | + AGUIApp, |
30 | 31 | ) |
31 | 32 | except ImportError as e: # pragma: no cover |
32 | 33 | raise ImportError( |
|
35 | 36 | ) from e |
36 | 37 |
|
37 | 38 | try: |
38 | | - from starlette.applications import Starlette |
39 | | - from starlette.middleware import Middleware |
40 | 39 | from starlette.requests import Request |
41 | 40 | from starlette.responses import Response |
42 | | - from starlette.routing import BaseRoute |
43 | | - from starlette.types import ExceptionHandler, Lifespan |
44 | 41 | except ImportError as e: # pragma: no cover |
45 | 42 | raise ImportError( |
46 | 43 | 'Please install the `starlette` package to use `Agent.to_ag_ui()` method, ' |
|
59 | 56 | ] |
60 | 57 |
|
61 | 58 |
|
62 | | -class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): |
63 | | - """ASGI application for running Pydantic AI agents with AG-UI protocol support.""" |
64 | | - |
65 | | - def __init__( |
66 | | - self, |
67 | | - agent: AbstractAgent[AgentDepsT, OutputDataT], |
68 | | - *, |
69 | | - # Agent.iter parameters. |
70 | | - output_type: OutputSpec[Any] | None = None, |
71 | | - message_history: Sequence[ModelMessage] | None = None, |
72 | | - deferred_tool_results: DeferredToolResults | None = None, |
73 | | - model: Model | KnownModelName | str | None = None, |
74 | | - deps: AgentDepsT = None, |
75 | | - model_settings: ModelSettings | None = None, |
76 | | - usage_limits: UsageLimits | None = None, |
77 | | - usage: RunUsage | None = None, |
78 | | - infer_name: bool = True, |
79 | | - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, |
80 | | - # Starlette parameters. |
81 | | - debug: bool = False, |
82 | | - routes: Sequence[BaseRoute] | None = None, |
83 | | - middleware: Sequence[Middleware] | None = None, |
84 | | - exception_handlers: Mapping[Any, ExceptionHandler] | None = None, |
85 | | - on_startup: Sequence[Callable[[], Any]] | None = None, |
86 | | - on_shutdown: Sequence[Callable[[], Any]] | None = None, |
87 | | - lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, |
88 | | - ) -> None: |
89 | | - """An ASGI application that handles every AG-UI request by running the agent. |
90 | | -
|
91 | | - Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's |
92 | | - injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol. |
93 | | - To provide different `deps` for each request (e.g. based on the authenticated user), |
94 | | - use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or |
95 | | - [`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead. |
96 | | -
|
97 | | - Args: |
98 | | - agent: The agent to run. |
99 | | -
|
100 | | - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has |
101 | | - no output validators since output validators would expect an argument that matches the agent's |
102 | | - output type. |
103 | | - message_history: History of the conversation so far. |
104 | | - deferred_tool_results: Optional results for deferred tool calls in the message history. |
105 | | - model: Optional model to use for this run, required if `model` was not set when creating the agent. |
106 | | - deps: Optional dependencies to use for this run. |
107 | | - model_settings: Optional settings to use for this model's request. |
108 | | - usage_limits: Optional limits on model request count or token usage. |
109 | | - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. |
110 | | - infer_name: Whether to try to infer the agent name from the call frame if it's not set. |
111 | | - toolsets: Optional additional toolsets for this run. |
112 | | -
|
113 | | - debug: Boolean indicating if debug tracebacks should be returned on errors. |
114 | | - routes: A list of routes to serve incoming HTTP and WebSocket requests. |
115 | | - middleware: A list of middleware to run for every request. A starlette application will always |
116 | | - automatically include two middleware classes. `ServerErrorMiddleware` is added as the very |
117 | | - outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. |
118 | | - `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled |
119 | | - exception cases occurring in the routing or endpoints. |
120 | | - exception_handlers: A mapping of either integer status codes, or exception class types onto |
121 | | - callables which handle the exceptions. Exception handler callables should be of the form |
122 | | - `handler(request, exc) -> response` and may be either standard functions, or async functions. |
123 | | - on_startup: A list of callables to run on application startup. Startup handler callables do not |
124 | | - take any arguments, and may be either standard functions, or async functions. |
125 | | - on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do |
126 | | - not take any arguments, and may be either standard functions, or async functions. |
127 | | - lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. |
128 | | - This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or |
129 | | - the other, not both. |
130 | | - """ |
131 | | - super().__init__( |
132 | | - debug=debug, |
133 | | - routes=routes, |
134 | | - middleware=middleware, |
135 | | - exception_handlers=exception_handlers, |
136 | | - on_startup=on_startup, |
137 | | - on_shutdown=on_shutdown, |
138 | | - lifespan=lifespan, |
139 | | - ) |
140 | | - |
141 | | - async def endpoint(request: Request) -> Response: |
142 | | - """Endpoint to run the agent with the provided input data.""" |
143 | | - return await handle_ag_ui_request( |
144 | | - agent, |
145 | | - request, |
146 | | - output_type=output_type, |
147 | | - message_history=message_history, |
148 | | - deferred_tool_results=deferred_tool_results, |
149 | | - model=model, |
150 | | - deps=deps, |
151 | | - model_settings=model_settings, |
152 | | - usage_limits=usage_limits, |
153 | | - usage=usage, |
154 | | - infer_name=infer_name, |
155 | | - toolsets=toolsets, |
156 | | - ) |
157 | | - |
158 | | - self.router.add_route('/', endpoint, methods=['POST'], name='run_agent') |
159 | | - |
160 | | - |
161 | 59 | async def handle_ag_ui_request( |
162 | 60 | agent: AbstractAgent[AgentDepsT, Any], |
163 | 61 | request: Request, |
@@ -202,6 +100,8 @@ async def handle_ag_ui_request( |
202 | 100 | request, |
203 | 101 | deps=deps, |
204 | 102 | output_type=output_type, |
| 103 | + message_history=message_history, |
| 104 | + deferred_tool_results=deferred_tool_results, |
205 | 105 | model=model, |
206 | 106 | model_settings=model_settings, |
207 | 107 | usage_limits=usage_limits, |
|
0 commit comments