Skip to content

Commit d834583

Browse files
committed
Reset chat app example
1 parent f8be256 commit d834583

File tree

2 files changed

+179
-142
lines changed

2 files changed

+179
-142
lines changed

examples/pydantic_ai_examples/chat_app.py

Lines changed: 179 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,97 +7,216 @@
77

88
from __future__ import annotations as _annotations
99

10+
import asyncio
11+
import json
12+
import sqlite3
13+
from collections.abc import AsyncIterator, Callable
14+
from concurrent.futures.thread import ThreadPoolExecutor
1015
from contextlib import asynccontextmanager
1116
from dataclasses import dataclass
17+
from datetime import datetime, timezone
18+
from functools import partial
1219
from pathlib import Path
20+
from typing import Annotated, Any, Literal, TypeVar
1321

1422
import fastapi
1523
import logfire
16-
from fastapi import Depends, Request, Response
17-
18-
from pydantic_ai import Agent, RunContext
19-
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
20-
21-
from .sqlite_database import Database
24+
from fastapi import Depends, Request
25+
from fastapi.responses import FileResponse, Response, StreamingResponse
26+
from typing_extensions import LiteralString, ParamSpec, TypedDict
27+
28+
from pydantic_ai import (
29+
Agent,
30+
ModelMessage,
31+
ModelMessagesTypeAdapter,
32+
ModelRequest,
33+
ModelResponse,
34+
TextPart,
35+
UnexpectedModelBehavior,
36+
UserPromptPart,
37+
)
2238

2339
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
2440
logfire.configure(send_to_logfire='if-token-present')
2541
logfire.instrument_pydantic_ai()
2642

43+
agent = Agent('openai:gpt-4o')
2744
THIS_DIR = Path(__file__).parent
28-
sql_schema = """
29-
create table if not exists memory(
30-
id integer primary key,
31-
user_id integer not null,
32-
value text not null,
33-
unique(user_id, value)
34-
);"""
3545

3646

3747
@asynccontextmanager
3848
async def lifespan(_app: fastapi.FastAPI):
39-
async with Database.connect(sql_schema) as db:
49+
async with Database.connect() as db:
4050
yield {'db': db}
4151

4252

43-
@dataclass
44-
class Deps:
45-
conn: Database
46-
user_id: int
47-
48-
49-
chat_agent = Agent(
50-
'openai:gpt-4.1',
51-
deps_type=Deps,
52-
instructions="""
53-
You are a helpful assistant.
54-
55-
Always reply with markdown. ALWAYS use code fences for code examples and lines of code.
56-
""",
57-
)
58-
59-
60-
@chat_agent.tool
61-
async def record_memory(ctx: RunContext[Deps], value: str) -> str:
62-
"""Use this tool to store information in memory."""
63-
await ctx.deps.conn.execute(
64-
'insert into memory(user_id, value) values(?, ?) on conflict do nothing',
65-
ctx.deps.user_id,
66-
value,
67-
commit=True,
68-
)
69-
return 'Value added to memory.'
53+
app = fastapi.FastAPI(lifespan=lifespan)
54+
logfire.instrument_fastapi(app)
7055

7156

72-
@chat_agent.tool
73-
async def retrieve_memories(ctx: RunContext[Deps], memory_contains: str) -> str:
74-
"""Get all memories about the user."""
75-
rows = await ctx.deps.conn.fetchall(
76-
'select value from memory where user_id = ? and value like ?',
77-
ctx.deps.user_id,
78-
f'%{memory_contains}%',
79-
)
80-
return '\n'.join([row[0] for row in rows])
57+
@app.get('/')
58+
async def index() -> FileResponse:
59+
return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html')
8160

8261

83-
app = fastapi.FastAPI(lifespan=lifespan)
84-
logfire.instrument_fastapi(app)
62+
@app.get('/chat_app.ts')
63+
async def main_ts() -> FileResponse:
64+
"""Get the raw typescript code, it's compiled in the browser, forgive me."""
65+
return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain')
8566

8667

8768
async def get_db(request: Request) -> Database:
8869
return request.state.db
8970

9071

91-
@app.options('/api/chat')
92-
def options_chat():
93-
pass
72+
@app.get('/chat/')
73+
async def get_chat(database: Database = Depends(get_db)) -> Response:
74+
msgs = await database.get_messages()
75+
return Response(
76+
b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs),
77+
media_type='text/plain',
78+
)
9479

9580

96-
@app.post('/api/chat')
97-
async def get_chat(request: Request, database: Database = Depends(get_db)) -> Response:
98-
return await VercelAIAdapter[Deps].dispatch_request(
99-
request, agent=chat_agent, deps=Deps(database, 123)
100-
)
81+
class ChatMessage(TypedDict):
82+
"""Format of messages sent to the browser."""
83+
84+
role: Literal['user', 'model']
85+
timestamp: str
86+
content: str
87+
88+
89+
def to_chat_message(m: ModelMessage) -> ChatMessage:
90+
first_part = m.parts[0]
91+
if isinstance(m, ModelRequest):
92+
if isinstance(first_part, UserPromptPart):
93+
assert isinstance(first_part.content, str)
94+
return {
95+
'role': 'user',
96+
'timestamp': first_part.timestamp.isoformat(),
97+
'content': first_part.content,
98+
}
99+
elif isinstance(m, ModelResponse):
100+
if isinstance(first_part, TextPart):
101+
return {
102+
'role': 'model',
103+
'timestamp': m.timestamp.isoformat(),
104+
'content': first_part.content,
105+
}
106+
raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
107+
108+
109+
@app.post('/chat/')
110+
async def post_chat(
111+
prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
112+
) -> StreamingResponse:
113+
async def stream_messages():
114+
"""Streams new line delimited JSON `Message`s to the client."""
115+
# stream the user prompt so that can be displayed straight away
116+
yield (
117+
json.dumps(
118+
{
119+
'role': 'user',
120+
'timestamp': datetime.now(tz=timezone.utc).isoformat(),
121+
'content': prompt,
122+
}
123+
).encode('utf-8')
124+
+ b'\n'
125+
)
126+
# get the chat history so far to pass as context to the agent
127+
messages = await database.get_messages()
128+
# run the agent with the user prompt and the chat history
129+
async with agent.run_stream(prompt, message_history=messages) as result:
130+
async for text in result.stream_output(debounce_by=0.01):
131+
# text here is a `str` and the frontend wants
132+
# JSON encoded ModelResponse, so we create one
133+
m = ModelResponse(parts=[TextPart(text)], timestamp=result.timestamp())
134+
yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
135+
136+
# add new messages (e.g. the user prompt and the agent response in this case) to the database
137+
await database.add_messages(result.new_messages_json())
138+
139+
return StreamingResponse(stream_messages(), media_type='text/plain')
140+
141+
142+
P = ParamSpec('P')
143+
R = TypeVar('R')
144+
145+
146+
@dataclass
147+
class Database:
148+
"""Rudimentary database to store chat messages in SQLite.
149+
150+
The SQLite standard library package is synchronous, so we
151+
use a thread pool executor to run queries asynchronously.
152+
"""
153+
154+
con: sqlite3.Connection
155+
_loop: asyncio.AbstractEventLoop
156+
_executor: ThreadPoolExecutor
157+
158+
@classmethod
159+
@asynccontextmanager
160+
async def connect(
161+
cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
162+
) -> AsyncIterator[Database]:
163+
with logfire.span('connect to DB'):
164+
loop = asyncio.get_event_loop()
165+
executor = ThreadPoolExecutor(max_workers=1)
166+
con = await loop.run_in_executor(executor, cls._connect, file)
167+
slf = cls(con, loop, executor)
168+
try:
169+
yield slf
170+
finally:
171+
await slf._asyncify(con.close)
172+
173+
@staticmethod
174+
def _connect(file: Path) -> sqlite3.Connection:
175+
con = sqlite3.connect(str(file))
176+
con = logfire.instrument_sqlite3(con)
177+
cur = con.cursor()
178+
cur.execute(
179+
'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);'
180+
)
181+
con.commit()
182+
return con
183+
184+
async def add_messages(self, messages: bytes):
185+
await self._asyncify(
186+
self._execute,
187+
'INSERT INTO messages (message_list) VALUES (?);',
188+
messages,
189+
commit=True,
190+
)
191+
await self._asyncify(self.con.commit)
192+
193+
async def get_messages(self) -> list[ModelMessage]:
194+
c = await self._asyncify(
195+
self._execute, 'SELECT message_list FROM messages order by id'
196+
)
197+
rows = await self._asyncify(c.fetchall)
198+
messages: list[ModelMessage] = []
199+
for row in rows:
200+
messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
201+
return messages
202+
203+
def _execute(
204+
self, sql: LiteralString, *args: Any, commit: bool = False
205+
) -> sqlite3.Cursor:
206+
cur = self.con.cursor()
207+
cur.execute(sql, args)
208+
if commit:
209+
self.con.commit()
210+
return cur
211+
212+
async def _asyncify(
213+
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
214+
) -> R:
215+
return await self._loop.run_in_executor( # type: ignore
216+
self._executor,
217+
partial(func, **kwargs),
218+
*args, # type: ignore
219+
)
101220

102221

103222
if __name__ == '__main__':

examples/pydantic_ai_examples/sqlite_database.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

0 commit comments

Comments
 (0)