Skip to content

Commit bdd321d

Browse files
committed
refactoring
1 parent e39612d commit bdd321d

File tree

11 files changed

+1767
-101
lines changed

11 files changed

+1767
-101
lines changed

examples/pydantic_ai_examples/chat_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from fastapi import Depends, Request, Response
1717

1818
from pydantic_ai import Agent, RunContext
19-
from pydantic_ai.vercel_ai_elements.starlette import StarletteChat
19+
from pydantic_ai.vercel_ai.starlette import StarletteChat
2020

2121
from .sqlite_database import Database
2222

pydantic_ai_slim/pydantic_ai/vercel_ai_elements/request_types.py renamed to pydantic_ai_slim/pydantic_ai/vercel_ai/request_types.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,16 @@ class UIMessage(CamelBaseModel):
249249
"""
250250

251251

252-
class SubmitMessage(CamelBaseModel):
252+
class SubmitMessage(CamelBaseModel, extra='allow'):
253253
"""Submit a message to the agent."""
254254

255255
trigger: Literal['submit-message']
256256
id: str
257257
messages: list[UIMessage]
258258

259-
model: str
260-
web_search: bool
259+
# TODO (DouweM): Update, make variable? I like `inference_params` from OpenAI ChatKit.
260+
# model: str
261+
# web_search: bool
261262

262263

263264
class RegenerateMessage(CamelBaseModel):
@@ -269,5 +270,5 @@ class RegenerateMessage(CamelBaseModel):
269270
message_id: str
270271

271272

272-
RequestData = SubmitMessage | RegenerateMessage
273-
request_data_schema: TypeAdapter[RequestData] = TypeAdapter(Annotated[RequestData, Discriminator('trigger')])
273+
RequestData = Annotated[SubmitMessage | RegenerateMessage, Discriminator('trigger')]
274+
request_data_ta: TypeAdapter[RequestData] = TypeAdapter(RequestData)

pydantic_ai_slim/pydantic_ai/vercel_ai_elements/response_stream.py renamed to pydantic_ai_slim/pydantic_ai/vercel_ai/response_stream.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,12 @@
88
from pydantic_core import to_json
99

1010
from .. import messages
11-
from ..agent import Agent
12-
from ..run import AgentRunResultEvent
13-
from ..tools import AgentDepsT
1411
from . import response_types as _t
1512

16-
__all__ = 'sse_stream', 'VERCEL_AI_ELEMENTS_HEADERS', 'EventStreamer'
17-
# no idea if this is important, but vercel sends it, therefore so am I
18-
VERCEL_AI_ELEMENTS_HEADERS = {'x-vercel-ai-ui-message-stream': 'v1'}
13+
__all__ = 'VERCEL_AI_DSP_HEADERS', 'EventStreamer'
1914

20-
21-
async def sse_stream(agent: Agent[AgentDepsT], user_prompt: str, deps: Any) -> AsyncIterator[str]:
22-
"""Stream events from an agent run as Vercel AI Elements events.
23-
24-
Args:
25-
agent: The agent to run.
26-
user_prompt: The user prompt to run the agent with.
27-
deps: The dependencies to pass to the agent.
28-
29-
Yields:
30-
An async iterator text lines to stream over SSE.
31-
"""
32-
event_streamer = EventStreamer()
33-
async for event in agent.run_stream_events(user_prompt, deps=deps):
34-
if not isinstance(event, AgentRunResultEvent):
35-
async for chunk in event_streamer.event_to_chunks(event):
36-
yield chunk.sse()
37-
async for chunk in event_streamer.finish():
38-
yield chunk.sse()
15+
# See https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
16+
VERCEL_AI_DSP_HEADERS = {'x-vercel-ai-ui-message-stream': 'v1'}
3917

4018

4119
@dataclass
@@ -136,6 +114,9 @@ def sse(self) -> str:
136114
def __str__(self) -> str:
137115
return 'DoneChunk<marker for the end of sse stream message>'
138116

117+
def __eq__(self, other: Any) -> bool:
118+
return isinstance(other, DoneChunk)
119+
139120

140121
def _json_dumps(obj: Any) -> str:
141122
return to_json(obj).decode('utf-8')
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from collections.abc import AsyncIterator
2+
from dataclasses import dataclass
3+
from typing import Generic
4+
5+
from pydantic import ValidationError
6+
7+
from ..agent import Agent
8+
from ..run import AgentRunResultEvent
9+
from ..tools import AgentDepsT
10+
from .request_types import RequestData, TextUIPart, request_data_ta
11+
from .response_stream import VERCEL_AI_DSP_HEADERS, DoneChunk, EventStreamer
12+
from .response_types import AbstractSSEChunk
13+
14+
try:
15+
from sse_starlette.sse import EventSourceResponse
16+
from starlette.requests import Request
17+
from starlette.responses import JSONResponse, Response
18+
except ImportError as e:
19+
raise ImportError('To use Vercel AI Elements, please install starlette and sse_starlette') from e
20+
21+
22+
@dataclass
23+
class StarletteChat(Generic[AgentDepsT]):
24+
"""Starlette support for Pydantic AI's Vercel AI Elements integration.
25+
26+
This can be used with either FastAPI or Starlette apps.
27+
"""
28+
29+
agent: Agent[AgentDepsT]
30+
31+
async def dispatch_request(self, request: Request, deps: AgentDepsT) -> Response:
32+
"""Handle a request and return a streamed SSE response.
33+
34+
Args:
35+
request: The incoming Starlette/FastAPI request.
36+
deps: The dependencies for the agent.
37+
38+
Returns:
39+
A streamed SSE response.
40+
"""
41+
try:
42+
data = request_data_ta.validate_json(await request.json())
43+
44+
async def run_sse() -> AsyncIterator[str]:
45+
async for chunk in self.run(data, deps=deps):
46+
yield chunk.sse()
47+
48+
return EventSourceResponse(run_sse(), headers=VERCEL_AI_DSP_HEADERS)
49+
except ValidationError as e:
50+
return JSONResponse({'errors': e.errors()}, status_code=422)
51+
except Exception as e:
52+
return JSONResponse({'errors': str(e)}, status_code=500)
53+
54+
async def run(self, data: RequestData, deps: AgentDepsT = None) -> AsyncIterator[AbstractSSEChunk | DoneChunk]:
55+
"""Stream events from an agent run as Vercel AI Elements events.
56+
57+
Args:
58+
data: The data to run the agent with.
59+
deps: The dependencies to pass to the agent.
60+
61+
Yields:
62+
An async iterator text lines to stream over SSE.
63+
"""
64+
# TODO (DouweM): Use .model and .builtin_tools
65+
66+
# TODO: Use entire message history
67+
68+
if not data.messages:
69+
raise ValueError('no messages provided')
70+
71+
message = data.messages[-1]
72+
prompt: list[str] = []
73+
for part in message.parts:
74+
if isinstance(part, TextUIPart):
75+
prompt.append(part.text)
76+
else:
77+
raise ValueError(f'Only text parts are supported yet, got {part}')
78+
79+
event_streamer = EventStreamer()
80+
async for event in self.agent.run_stream_events('\n'.join(prompt), deps=deps):
81+
if not isinstance(event, AgentRunResultEvent):
82+
async for chunk in event_streamer.event_to_chunks(event):
83+
yield chunk
84+
async for chunk in event_streamer.finish():
85+
yield chunk

pydantic_ai_slim/pydantic_ai/vercel_ai_elements/starlette.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)