Skip to content

Commit 0871ac7

Browse files
committed
Add UIApp, AGUIApp, VercelAIApp
1 parent 5bcc597 commit 0871ac7

File tree

13 files changed

+196
-133
lines changed

13 files changed

+196
-133
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 6 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
from __future__ import annotations
88

9-
from collections.abc import AsyncIterator, Callable, Mapping, Sequence
10-
from typing import Any, Generic
9+
from collections.abc import AsyncIterator, Sequence
10+
from typing import Any
1111

1212
from . import DeferredToolResults
1313
from .agent import AbstractAgent
1414
from .messages import ModelMessage
1515
from .models import KnownModelName, Model
16-
from .output import OutputDataT, OutputSpec
16+
from .output import OutputSpec
1717
from .settings import ModelSettings
1818
from .tools import AgentDepsT
1919
from .toolsets import AbstractToolset
@@ -27,6 +27,7 @@
2727
from .ui.ag_ui import (
2828
SSE_CONTENT_TYPE,
2929
AGUIAdapter,
30+
AGUIApp,
3031
)
3132
except ImportError as e: # pragma: no cover
3233
raise ImportError(
@@ -35,12 +36,8 @@
3536
) from e
3637

3738
try:
38-
from starlette.applications import Starlette
39-
from starlette.middleware import Middleware
4039
from starlette.requests import Request
4140
from starlette.responses import Response
42-
from starlette.routing import BaseRoute
43-
from starlette.types import ExceptionHandler, Lifespan
4441
except ImportError as e: # pragma: no cover
4542
raise ImportError(
4643
'Please install the `starlette` package to use `Agent.to_ag_ui()` method, '
@@ -59,105 +56,6 @@
5956
]
6057

6158

62-
class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
63-
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
64-
65-
def __init__(
66-
self,
67-
agent: AbstractAgent[AgentDepsT, OutputDataT],
68-
*,
69-
# Agent.iter parameters.
70-
output_type: OutputSpec[Any] | None = None,
71-
message_history: Sequence[ModelMessage] | None = None,
72-
deferred_tool_results: DeferredToolResults | None = None,
73-
model: Model | KnownModelName | str | None = None,
74-
deps: AgentDepsT = None,
75-
model_settings: ModelSettings | None = None,
76-
usage_limits: UsageLimits | None = None,
77-
usage: RunUsage | None = None,
78-
infer_name: bool = True,
79-
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
80-
# Starlette parameters.
81-
debug: bool = False,
82-
routes: Sequence[BaseRoute] | None = None,
83-
middleware: Sequence[Middleware] | None = None,
84-
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
85-
on_startup: Sequence[Callable[[], Any]] | None = None,
86-
on_shutdown: Sequence[Callable[[], Any]] | None = None,
87-
lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None,
88-
) -> None:
89-
"""An ASGI application that handles every AG-UI request by running the agent.
90-
91-
Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's
92-
injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol.
93-
To provide different `deps` for each request (e.g. based on the authenticated user),
94-
use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or
95-
[`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead.
96-
97-
Args:
98-
agent: The agent to run.
99-
100-
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has
101-
no output validators since output validators would expect an argument that matches the agent's
102-
output type.
103-
message_history: History of the conversation so far.
104-
deferred_tool_results: Optional results for deferred tool calls in the message history.
105-
model: Optional model to use for this run, required if `model` was not set when creating the agent.
106-
deps: Optional dependencies to use for this run.
107-
model_settings: Optional settings to use for this model's request.
108-
usage_limits: Optional limits on model request count or token usage.
109-
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
110-
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
111-
toolsets: Optional additional toolsets for this run.
112-
113-
debug: Boolean indicating if debug tracebacks should be returned on errors.
114-
routes: A list of routes to serve incoming HTTP and WebSocket requests.
115-
middleware: A list of middleware to run for every request. A starlette application will always
116-
automatically include two middleware classes. `ServerErrorMiddleware` is added as the very
117-
outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack.
118-
`ExceptionMiddleware` is added as the very innermost middleware, to deal with handled
119-
exception cases occurring in the routing or endpoints.
120-
exception_handlers: A mapping of either integer status codes, or exception class types onto
121-
callables which handle the exceptions. Exception handler callables should be of the form
122-
`handler(request, exc) -> response` and may be either standard functions, or async functions.
123-
on_startup: A list of callables to run on application startup. Startup handler callables do not
124-
take any arguments, and may be either standard functions, or async functions.
125-
on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do
126-
not take any arguments, and may be either standard functions, or async functions.
127-
lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks.
128-
This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or
129-
the other, not both.
130-
"""
131-
super().__init__(
132-
debug=debug,
133-
routes=routes,
134-
middleware=middleware,
135-
exception_handlers=exception_handlers,
136-
on_startup=on_startup,
137-
on_shutdown=on_shutdown,
138-
lifespan=lifespan,
139-
)
140-
141-
async def endpoint(request: Request) -> Response:
142-
"""Endpoint to run the agent with the provided input data."""
143-
return await handle_ag_ui_request(
144-
agent,
145-
request,
146-
output_type=output_type,
147-
message_history=message_history,
148-
deferred_tool_results=deferred_tool_results,
149-
model=model,
150-
deps=deps,
151-
model_settings=model_settings,
152-
usage_limits=usage_limits,
153-
usage=usage,
154-
infer_name=infer_name,
155-
toolsets=toolsets,
156-
)
157-
158-
self.router.add_route('/', endpoint, methods=['POST'], name='run_agent')
159-
160-
16159
async def handle_ag_ui_request(
16260
agent: AbstractAgent[AgentDepsT, Any],
16361
request: Request,
@@ -202,6 +100,8 @@ async def handle_ag_ui_request(
202100
request,
203101
deps=deps,
204102
output_type=output_type,
103+
message_history=message_history,
104+
deferred_tool_results=deferred_tool_results,
205105
model=model,
206106
model_settings=model_settings,
207107
usage_limits=usage_limits,

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,7 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
993993
async def __aexit__(self, *args: Any) -> bool | None:
994994
raise NotImplementedError
995995

996+
# TODO (v2): Remove in favor of using `AGUIApp` directly -- we don't have `to_temporal()` or `to_vercel_ai()` either.
996997
def to_ag_ui(
997998
self,
998999
*,

pydantic_ai_slim/pydantic_ai/ui/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66

77
from __future__ import annotations
88

9-
from .adapter import BaseAdapter, OnCompleteFunc, StateDeps, StateHandler
10-
from .event_stream import SSE_CONTENT_TYPE, BaseEventStream
9+
from .adapter import OnCompleteFunc, StateDeps, StateHandler, UIAdapter
10+
from .app import UIApp
11+
from .event_stream import SSE_CONTENT_TYPE, UIEventStream
1112

1213
__all__ = [
13-
'BaseAdapter',
14-
'BaseEventStream',
14+
'UIAdapter',
15+
'UIEventStream',
1516
'SSE_CONTENT_TYPE',
1617
'StateDeps',
1718
'StateHandler',
1819
'OnCompleteFunc',
20+
'UIApp',
1921
]

pydantic_ai_slim/pydantic_ai/ui/adapter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@
3333
from ..settings import ModelSettings
3434
from ..toolsets import AbstractToolset
3535
from ..usage import RunUsage, UsageLimits
36-
from .event_stream import BaseEventStream, OnCompleteFunc, SourceEvent
36+
from .event_stream import OnCompleteFunc, SourceEvent, UIEventStream
3737

3838
if TYPE_CHECKING:
3939
from starlette.requests import Request
4040
from starlette.responses import Response
4141

4242

4343
__all__ = [
44-
'BaseAdapter',
44+
'UIAdapter',
4545
]
4646

4747

@@ -107,7 +107,7 @@ class StateDeps(Generic[StateT]):
107107

108108

109109
@dataclass
110-
class BaseAdapter(ABC, Generic[RunRequestT, MessageT, EventT, AgentDepsT, OutputDataT]):
110+
class UIAdapter(ABC, Generic[RunRequestT, MessageT, EventT, AgentDepsT, OutputDataT]):
111111
"""TODO (DouwM): Docstring."""
112112

113113
agent: AbstractAgent[AgentDepsT, OutputDataT]
@@ -131,7 +131,7 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
131131
@abstractmethod
132132
def build_event_stream(
133133
self, accept: str | None = None
134-
) -> BaseEventStream[RunRequestT, EventT, AgentDepsT, OutputDataT]:
134+
) -> UIEventStream[RunRequestT, EventT, AgentDepsT, OutputDataT]:
135135
"""Create an event stream for the adapter.
136136
137137
Args:
@@ -235,7 +235,7 @@ async def run_stream(
235235
toolset = self.toolset
236236
if toolset:
237237
output_type = [output_type or self.agent.output_type, DeferredToolRequests]
238-
toolsets = [*toolsets, toolset] if toolsets else [toolset]
238+
toolsets = [*(toolsets or []), toolset]
239239

240240
if isinstance(deps, StateHandler):
241241
raw_state = self.state or {}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
"""AG-UI protocol integration for Pydantic AI agents."""
22

3+
from typing import Any
4+
5+
from pydantic_ai.agent import AbstractAgent
6+
from pydantic_ai.output import OutputDataT
7+
from pydantic_ai.tools import AgentDepsT
8+
9+
from .. import UIApp
310
from ._adapter import AGUIAdapter
411
from ._event_stream import SSE_CONTENT_TYPE, AGUIEventStream
512

613
__all__ = [
714
'AGUIAdapter',
815
'AGUIEventStream',
916
'SSE_CONTENT_TYPE',
17+
'AGUIApp',
1018
]
19+
20+
21+
class AGUIApp(UIApp[AgentDepsT, OutputDataT]):
22+
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
23+
24+
def __init__(self, agent: AbstractAgent[AgentDepsT, OutputDataT], **kwargs: Any) -> None:
25+
super().__init__(AGUIAdapter[AgentDepsT, OutputDataT], agent, **kwargs)

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
UserMessage,
4242
)
4343

44-
from ..adapter import BaseAdapter
45-
from ..event_stream import BaseEventStream
44+
from ..adapter import UIAdapter
45+
from ..event_stream import UIEventStream
4646
from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream
4747
except ImportError as e: # pragma: no cover
4848
raise ImportError(
@@ -93,7 +93,7 @@ def label(self) -> str:
9393
return 'the AG-UI frontend tools' # pragma: no cover
9494

9595

96-
class AGUIAdapter(BaseAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT, OutputDataT]):
96+
class AGUIAdapter(UIAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT, OutputDataT]):
9797
"""TODO (DouwM): Docstring."""
9898

9999
@classmethod
@@ -103,7 +103,7 @@ async def validate_request(cls, request: Request) -> RunAgentInput:
103103

104104
def build_event_stream(
105105
self, accept: str | None = None
106-
) -> BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]:
106+
) -> UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]:
107107
"""Create an event stream for the adapter.
108108
109109
Args:

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from ...output import OutputDataT
2828
from ...tools import AgentDepsT
29-
from .. import SSE_CONTENT_TYPE, BaseEventStream
29+
from .. import SSE_CONTENT_TYPE, UIEventStream
3030

3131
try:
3232
from ag_ui.core import (
@@ -68,7 +68,7 @@
6868

6969

7070
@dataclass
71-
class AGUIEventStream(BaseEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]):
71+
class AGUIEventStream(UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]):
7272
"""TODO (DouwM): Docstring."""
7373

7474
_thinking_text: bool = False

0 commit comments

Comments
 (0)