|
1 | 1 | """AG-UI protocol integration for Pydantic AI agents.""" |
2 | 2 |
|
3 | | -from typing import Any |
| 3 | +from __future__ import annotations |
4 | 4 |
|
| 5 | +from collections.abc import Callable, Mapping, Sequence |
| 6 | +from typing import Any, Generic |
| 7 | + |
| 8 | +from typing_extensions import Self |
| 9 | + |
| 10 | +from pydantic_ai import DeferredToolResults |
5 | 11 | from pydantic_ai.agent import AbstractAgent |
6 | | -from pydantic_ai.output import OutputDataT |
| 12 | +from pydantic_ai.builtin_tools import AbstractBuiltinTool |
| 13 | +from pydantic_ai.messages import ModelMessage |
| 14 | +from pydantic_ai.models import KnownModelName, Model |
| 15 | +from pydantic_ai.output import OutputDataT, OutputSpec |
| 16 | +from pydantic_ai.settings import ModelSettings |
7 | 17 | from pydantic_ai.tools import AgentDepsT |
| 18 | +from pydantic_ai.toolsets import AbstractToolset |
| 19 | +from pydantic_ai.usage import RunUsage, UsageLimits |
8 | 20 |
|
9 | | -from ..app import UIApp |
| 21 | +from .. import OnCompleteFunc |
10 | 22 | from ._adapter import AGUIAdapter |
11 | 23 |
|
| 24 | +try: |
| 25 | + from starlette.applications import Starlette |
| 26 | + from starlette.middleware import Middleware |
| 27 | + from starlette.requests import Request |
| 28 | + from starlette.responses import Response |
| 29 | + from starlette.routing import BaseRoute |
| 30 | + from starlette.types import ExceptionHandler, Lifespan |
| 31 | +except ImportError as e: # pragma: no cover |
| 32 | + raise ImportError( |
| 33 | + 'Please install the `starlette` package to use `AGUIApp`, ' |
| 34 | + 'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`' |
| 35 | + ) from e |
| 36 | + |
12 | 37 |
|
13 | | -class AGUIApp(UIApp[AgentDepsT, OutputDataT]): |
| 38 | +class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): |
14 | 39 | """ASGI application for running Pydantic AI agents with AG-UI protocol support.""" |
15 | 40 |
|
16 | | - def __init__(self, agent: AbstractAgent[AgentDepsT, OutputDataT], **kwargs: Any): |
17 | | - super().__init__(AGUIAdapter[AgentDepsT, OutputDataT], agent, **kwargs) |
| 41 | + def __init__( |
| 42 | + self, |
| 43 | + agent: AbstractAgent[AgentDepsT, OutputDataT], |
| 44 | + *, |
| 45 | + # AGUIAdapter.dispatch_request parameters |
| 46 | + output_type: OutputSpec[Any] | None = None, |
| 47 | + message_history: Sequence[ModelMessage] | None = None, |
| 48 | + deferred_tool_results: DeferredToolResults | None = None, |
| 49 | + model: Model | KnownModelName | str | None = None, |
| 50 | + deps: AgentDepsT = None, |
| 51 | + model_settings: ModelSettings | None = None, |
| 52 | + usage_limits: UsageLimits | None = None, |
| 53 | + usage: RunUsage | None = None, |
| 54 | + infer_name: bool = True, |
| 55 | + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, |
| 56 | + builtin_tools: Sequence[AbstractBuiltinTool] | None = None, |
| 57 | + on_complete: OnCompleteFunc[Any] | None = None, |
| 58 | + # Starlette parameters |
| 59 | + debug: bool = False, |
| 60 | + routes: Sequence[BaseRoute] | None = None, |
| 61 | + middleware: Sequence[Middleware] | None = None, |
| 62 | + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, |
| 63 | + on_startup: Sequence[Callable[[], Any]] | None = None, |
| 64 | + on_shutdown: Sequence[Callable[[], Any]] | None = None, |
| 65 | + lifespan: Lifespan[Self] | None = None, |
| 66 | + ) -> None: |
| 67 | + """An ASGI application that handles every request by running the agent and streaming the response. |
| 68 | +
|
| 69 | + Note that the `deps` will be the same for each request, with the exception of the frontend state that's |
| 70 | + injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ui.StateHandler] protocol. |
| 71 | + To provide different `deps` for each request (e.g. based on the authenticated user), |
| 72 | + use [`AGUIAdapter.run_stream()`][pydantic_ai.ui.ag_ui.AGUIAdapter.run_stream] or |
| 73 | + [`AGUIAdapter.dispatch_request()`][pydantic_ai.ui.ag_ui.AGUIAdapter.dispatch_request] instead. |
| 74 | +
|
| 75 | + Args: |
| 76 | + agent: The agent to run. |
| 77 | +
|
| 78 | + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has |
| 79 | + no output validators since output validators would expect an argument that matches the agent's |
| 80 | + output type. |
| 81 | + message_history: History of the conversation so far. |
| 82 | + deferred_tool_results: Optional results for deferred tool calls in the message history. |
| 83 | + model: Optional model to use for this run, required if `model` was not set when creating the agent. |
| 84 | + deps: Optional dependencies to use for this run. |
| 85 | + model_settings: Optional settings to use for this model's request. |
| 86 | + usage_limits: Optional limits on model request count or token usage. |
| 87 | + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. |
| 88 | + infer_name: Whether to try to infer the agent name from the call frame if it's not set. |
| 89 | + toolsets: Optional additional toolsets for this run. |
| 90 | + builtin_tools: Optional additional builtin tools for this run. |
| 91 | + on_complete: Optional callback function called when the agent run completes successfully. |
| 92 | + The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data. |
| 93 | +
|
| 94 | + debug: Boolean indicating if debug tracebacks should be returned on errors. |
| 95 | + routes: A list of routes to serve incoming HTTP and WebSocket requests. |
| 96 | + middleware: A list of middleware to run for every request. A starlette application will always |
| 97 | + automatically include two middleware classes. `ServerErrorMiddleware` is added as the very |
| 98 | + outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. |
| 99 | + `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled |
| 100 | + exception cases occurring in the routing or endpoints. |
| 101 | + exception_handlers: A mapping of either integer status codes, or exception class types onto |
| 102 | + callables which handle the exceptions. Exception handler callables should be of the form |
| 103 | + `handler(request, exc) -> response` and may be either standard functions, or async functions. |
| 104 | + on_startup: A list of callables to run on application startup. Startup handler callables do not |
| 105 | + take any arguments, and may be either standard functions, or async functions. |
| 106 | + on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do |
| 107 | + not take any arguments, and may be either standard functions, or async functions. |
| 108 | + lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. |
| 109 | + This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or |
| 110 | + the other, not both. |
| 111 | + """ |
| 112 | + super().__init__( |
| 113 | + debug=debug, |
| 114 | + routes=routes, |
| 115 | + middleware=middleware, |
| 116 | + exception_handlers=exception_handlers, |
| 117 | + on_startup=on_startup, |
| 118 | + on_shutdown=on_shutdown, |
| 119 | + lifespan=lifespan, |
| 120 | + ) |
| 121 | + |
| 122 | + async def run_agent(request: Request) -> Response: |
| 123 | + """Endpoint to run the agent with the provided input data.""" |
| 124 | + return await AGUIAdapter.dispatch_request( |
| 125 | + request, |
| 126 | + agent=agent, |
| 127 | + output_type=output_type, |
| 128 | + message_history=message_history, |
| 129 | + deferred_tool_results=deferred_tool_results, |
| 130 | + model=model, |
| 131 | + deps=deps, |
| 132 | + model_settings=model_settings, |
| 133 | + usage_limits=usage_limits, |
| 134 | + usage=usage, |
| 135 | + infer_name=infer_name, |
| 136 | + toolsets=toolsets, |
| 137 | + builtin_tools=builtin_tools, |
| 138 | + on_complete=on_complete, |
| 139 | + ) |
| 140 | + |
| 141 | + self.router.add_route('/', run_agent, methods=['POST']) |
0 commit comments