Skip to content

Commit f0a03d9

Browse files
committed
Claude-assisted refactoring to unify AG-UI and Vercel AI adapters and event streams
1 parent bdd321d commit f0a03d9

File tree

18 files changed

+2045
-1093
lines changed

18 files changed

+2045
-1093
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 29 additions & 542 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/ui/__init__.py

Lines changed: 457 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""AG-UI protocol integration for Pydantic AI agents."""
2+
3+
from .adapter import AGUIAdapter
4+
from .event_stream import AGUIEventStream, StateDeps, StateHandler, protocol_messages_to_pai_messages
5+
6+
__all__ = [
7+
'AGUIAdapter',
8+
'AGUIEventStream',
9+
'StateHandler',
10+
'StateDeps',
11+
'protocol_messages_to_pai_messages',
12+
]
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
"""AG-UI adapter for handling requests."""
2+
3+
# pyright: reportGeneralTypeIssues=false, reportInvalidTypeArguments=false
4+
5+
from __future__ import annotations
6+
7+
import json
8+
from dataclasses import dataclass
9+
from http import HTTPStatus
10+
from typing import TYPE_CHECKING, Any
11+
12+
from pydantic import BaseModel, ValidationError
13+
14+
from ...tools import AgentDepsT
15+
from .event_stream import (
16+
AGUIEventStream,
17+
RunAgentInput,
18+
StateHandler,
19+
_AGUIFrontendToolset, # type: ignore[reportPrivateUsage]
20+
_InvalidStateError, # type: ignore[reportPrivateUsage]
21+
_NoMessagesError, # type: ignore[reportPrivateUsage]
22+
_RunError, # type: ignore[reportPrivateUsage]
23+
protocol_messages_to_pai_messages,
24+
)
25+
26+
if TYPE_CHECKING:
27+
from ...agent import Agent
28+
29+
__all__ = ['AGUIAdapter']
30+
31+
32+
@dataclass
33+
class AGUIAdapter:
34+
"""Adapter for handling AG-UI protocol requests with Pydantic AI agents.
35+
36+
This adapter provides an interface for integrating Pydantic AI agents
37+
with the AG-UI protocol, handling request parsing, message conversion,
38+
and event streaming.
39+
40+
Example:
41+
```python
42+
from pydantic_ai import Agent
43+
from pydantic_ai.ui.ag_ui import AGUIAdapter
44+
45+
agent = Agent('openai:gpt-4')
46+
adapter = AGUIAdapter(agent)
47+
48+
async def handle_request(request: RunAgentInput, deps=None):
49+
async for event_str in adapter.run_stream_sse(request, deps):
50+
yield event_str
51+
```
52+
"""
53+
54+
agent: Agent[AgentDepsT]
55+
"""The Pydantic AI agent to run."""
56+
57+
async def run_stream( # noqa: C901
58+
self,
59+
request: RunAgentInput,
60+
deps: AgentDepsT | None = None,
61+
*,
62+
output_type: Any = None,
63+
model: Any = None,
64+
model_settings: Any = None,
65+
usage_limits: Any = None,
66+
usage: Any = None,
67+
infer_name: bool = True,
68+
toolsets: Any = None,
69+
on_complete: Any = None,
70+
):
71+
"""Stream events from an agent run as AG-UI protocol events.
72+
73+
This method provides a complete implementation with all AG-UI features including:
74+
- Frontend tools handling
75+
- State injection
76+
- Error handling (validation vs stream errors)
77+
- on_complete callback
78+
- RunStarted and RunFinished events
79+
80+
Args:
81+
request: The AG-UI request data.
82+
deps: Optional dependencies to pass to the agent.
83+
output_type: Custom output type for this run.
84+
model: Optional model to use for this run.
85+
model_settings: Optional settings for the model's request.
86+
usage_limits: Optional limits on model request count or token usage.
87+
usage: Optional usage to start with.
88+
infer_name: Whether to infer the agent name from the call frame.
89+
toolsets: Optional additional toolsets for this run.
90+
on_complete: Optional callback called when the agent run completes.
91+
92+
Yields:
93+
AG-UI protocol events (BaseEvent subclasses).
94+
95+
Raises:
96+
_RunError: If request validation fails or other errors occur.
97+
"""
98+
from ... import _utils
99+
from ...exceptions import UserError
100+
from ...tools import DeferredToolRequests
101+
from .event_stream import RunFinishedEvent, RunStartedEvent
102+
103+
# Create event stream
104+
event_stream = self.create_event_stream()
105+
stream_started = False
106+
107+
# Handle frontend tools
108+
if request.tools:
109+
toolset = _AGUIFrontendToolset[AgentDepsT](request.tools)
110+
toolsets = [*toolsets, toolset] if toolsets else [toolset]
111+
112+
try:
113+
# Emit start event
114+
yield RunStartedEvent(
115+
thread_id=request.thread_id,
116+
run_id=request.run_id,
117+
)
118+
stream_started = True
119+
120+
if not request.messages:
121+
raise _NoMessagesError
122+
123+
# Handle state injection
124+
raw_state: dict[str, Any] = request.state or {}
125+
if isinstance(deps, StateHandler):
126+
if isinstance(deps.state, BaseModel):
127+
try:
128+
state = type(deps.state).model_validate(raw_state)
129+
except ValidationError as e: # pragma: no cover
130+
raise _InvalidStateError from e
131+
else:
132+
state = raw_state
133+
134+
from dataclasses import replace
135+
136+
deps = replace(deps, state=state)
137+
elif raw_state:
138+
raise UserError(
139+
f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.'
140+
)
141+
142+
# Convert AG-UI messages to pAI messages
143+
messages = protocol_messages_to_pai_messages(request.messages)
144+
145+
# Run agent and stream events
146+
result = None
147+
async for event in self.agent.run_stream_events(
148+
user_prompt=None,
149+
output_type=[output_type or self.agent.output_type, DeferredToolRequests],
150+
message_history=messages,
151+
model=model,
152+
deps=deps,
153+
model_settings=model_settings,
154+
usage_limits=usage_limits,
155+
usage=usage,
156+
infer_name=infer_name,
157+
toolsets=toolsets,
158+
):
159+
from ...run import AgentRunResultEvent
160+
161+
# Capture result for on_complete callback
162+
if isinstance(event, AgentRunResultEvent):
163+
result = event.result
164+
165+
# Transform pAI events to AG-UI events
166+
async for ag_ui_event in event_stream.agent_event_to_events(event): # type: ignore[arg-type]
167+
yield ag_ui_event
168+
169+
# Call on_complete callback
170+
if on_complete is not None and result is not None:
171+
if _utils.is_async_callable(on_complete):
172+
await on_complete(result)
173+
else:
174+
await _utils.run_in_executor(on_complete, result)
175+
176+
except _RunError as e:
177+
if stream_started:
178+
async for error_event in event_stream.on_stream_error(e):
179+
yield error_event
180+
else:
181+
async for error_event in event_stream.on_validation_error(e):
182+
yield error_event
183+
raise
184+
except Exception as e:
185+
if stream_started:
186+
async for error_event in event_stream.on_stream_error(e):
187+
yield error_event
188+
else:
189+
async for error_event in event_stream.on_validation_error(e):
190+
yield error_event
191+
raise
192+
else:
193+
# Emit finish event
194+
yield RunFinishedEvent(
195+
thread_id=request.thread_id,
196+
run_id=request.run_id,
197+
)
198+
199+
async def run_stream_sse(
200+
self,
201+
request: RunAgentInput,
202+
accept: str,
203+
*,
204+
output_type: Any = None,
205+
model: Any = None,
206+
deps: AgentDepsT | None = None,
207+
model_settings: Any = None,
208+
usage_limits: Any = None,
209+
usage: Any = None,
210+
infer_name: bool = True,
211+
toolsets: Any = None,
212+
on_complete: Any = None,
213+
):
214+
"""Stream SSE-encoded events from an agent run.
215+
216+
This method wraps `run_stream` and encodes the events as SSE strings.
217+
218+
Args:
219+
request: The AG-UI request data.
220+
accept: The accept header value for encoding.
221+
output_type: Custom output type for this run.
222+
model: Optional model to use for this run.
223+
deps: Optional dependencies to pass to the agent.
224+
model_settings: Optional settings for the model's request.
225+
usage_limits: Optional limits on model request count or token usage.
226+
usage: Optional usage to start with.
227+
infer_name: Whether to infer the agent name from the call frame.
228+
toolsets: Optional additional toolsets for this run.
229+
on_complete: Optional callback called when the agent run completes.
230+
231+
Yields:
232+
SSE-formatted strings.
233+
"""
234+
from ag_ui.encoder import EventEncoder
235+
236+
encoder = EventEncoder(accept=accept)
237+
238+
try:
239+
async for event in self.run_stream(
240+
request=request,
241+
deps=deps,
242+
output_type=output_type,
243+
model=model,
244+
model_settings=model_settings,
245+
usage_limits=usage_limits,
246+
usage=usage,
247+
infer_name=infer_name,
248+
toolsets=toolsets,
249+
on_complete=on_complete,
250+
):
251+
yield encoder.encode(event)
252+
except _RunError:
253+
# Error events are already yielded by run_stream
254+
# This shouldn't actually be reached since run_stream yields error events before raising
255+
pass
256+
except Exception:
257+
# Let other exceptions propagate
258+
raise
259+
260+
async def dispatch_request(
261+
self,
262+
request: Any,
263+
deps: AgentDepsT | None = None,
264+
*,
265+
output_type: Any = None,
266+
model: Any = None,
267+
model_settings: Any = None,
268+
usage_limits: Any = None,
269+
usage: Any = None,
270+
infer_name: bool = True,
271+
toolsets: Any = None,
272+
on_complete: Any = None,
273+
) -> Any:
274+
"""Handle an AG-UI request and return a streaming response.
275+
276+
Args:
277+
request: The incoming Starlette/FastAPI request.
278+
deps: Optional dependencies to pass to the agent.
279+
output_type: Custom output type for this run.
280+
model: Optional model to use for this run.
281+
model_settings: Optional settings for the model's request.
282+
usage_limits: Optional limits on model request count or token usage.
283+
usage: Optional usage to start with.
284+
infer_name: Whether to infer the agent name from the call frame.
285+
toolsets: Optional additional toolsets for this run.
286+
on_complete: Optional callback called when the agent run completes.
287+
288+
Returns:
289+
A streaming Starlette response with AG-UI protocol events.
290+
"""
291+
try:
292+
from starlette.requests import Request
293+
from starlette.responses import Response, StreamingResponse
294+
except ImportError as e: # pragma: no cover
295+
raise ImportError('Please install starlette to use dispatch_request') from e
296+
297+
if not isinstance(request, Request): # pragma: no cover
298+
raise TypeError(f'Expected Starlette Request, got {type(request).__name__}')
299+
300+
accept = request.headers.get('accept', 'text/event-stream')
301+
302+
try:
303+
input_data = RunAgentInput.model_validate(await request.json())
304+
except ValidationError as e: # pragma: no cover
305+
return Response(
306+
content=json.dumps(e.json()),
307+
media_type='application/json',
308+
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
309+
)
310+
311+
return StreamingResponse(
312+
self.run_stream_sse(
313+
request=input_data,
314+
accept=accept,
315+
deps=deps,
316+
output_type=output_type,
317+
model=model,
318+
model_settings=model_settings,
319+
usage_limits=usage_limits,
320+
usage=usage,
321+
infer_name=infer_name,
322+
toolsets=toolsets,
323+
on_complete=on_complete,
324+
),
325+
media_type=accept,
326+
)
327+
328+
def create_event_stream(self) -> AGUIEventStream[AgentDepsT]:
329+
"""Create a new AG-UI event stream.
330+
331+
Returns:
332+
An AGUIEventStream instance.
333+
"""
334+
return AGUIEventStream[AgentDepsT]()

0 commit comments

Comments
 (0)