|
7 | 7 |
|
8 | 8 | from __future__ import annotations as _annotations |
9 | 9 |
|
| 10 | +import asyncio |
| 11 | +import json |
| 12 | +import sqlite3 |
| 13 | +from collections.abc import AsyncIterator, Callable |
| 14 | +from concurrent.futures.thread import ThreadPoolExecutor |
10 | 15 | from contextlib import asynccontextmanager |
11 | 16 | from dataclasses import dataclass |
| 17 | +from datetime import datetime, timezone |
| 18 | +from functools import partial |
12 | 19 | from pathlib import Path |
| 20 | +from typing import Annotated, Any, Literal, TypeVar |
13 | 21 |
|
14 | 22 | import fastapi |
15 | 23 | import logfire |
16 | | -from fastapi import Depends, Request, Response |
17 | | - |
18 | | -from pydantic_ai import Agent, RunContext |
19 | | -from pydantic_ai.ui.vercel_ai import VercelAIAdapter |
20 | | - |
21 | | -from .sqlite_database import Database |
| 24 | +from fastapi import Depends, Request |
| 25 | +from fastapi.responses import FileResponse, Response, StreamingResponse |
| 26 | +from typing_extensions import LiteralString, ParamSpec, TypedDict |
| 27 | + |
| 28 | +from pydantic_ai import ( |
| 29 | + Agent, |
| 30 | + ModelMessage, |
| 31 | + ModelMessagesTypeAdapter, |
| 32 | + ModelRequest, |
| 33 | + ModelResponse, |
| 34 | + TextPart, |
| 35 | + UnexpectedModelBehavior, |
| 36 | + UserPromptPart, |
| 37 | +) |
22 | 38 |
|
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_pydantic_ai()

# The agent that answers chat prompts; swap the model string to use another provider/model.
agent = Agent('openai:gpt-4o')
# Directory containing this file; the static assets (HTML/TS) and SQLite file live alongside it.
THIS_DIR = Path(__file__).parent
35 | 45 |
|
36 | 46 |
|
@asynccontextmanager
async def lifespan(_app: fastapi.FastAPI):
    """FastAPI lifespan: open the SQLite message store for the app's lifetime.

    The yielded mapping becomes request state, so handlers reach the
    database as ``request.state.db`` (see ``get_db`` below).
    """
    async with Database.connect() as db:
        yield {'db': db}
41 | 51 |
|
42 | 52 |
|
43 | | -@dataclass |
44 | | -class Deps: |
45 | | - conn: Database |
46 | | - user_id: int |
47 | | - |
48 | | - |
49 | | -chat_agent = Agent( |
50 | | - 'openai:gpt-4.1', |
51 | | - deps_type=Deps, |
52 | | - instructions=""" |
53 | | -You are a helpful assistant. |
54 | | -
|
55 | | -Always reply with markdown. ALWAYS use code fences for code examples and lines of code. |
56 | | -""", |
57 | | -) |
58 | | - |
59 | | - |
60 | | -@chat_agent.tool |
61 | | -async def record_memory(ctx: RunContext[Deps], value: str) -> str: |
62 | | - """Use this tool to store information in memory.""" |
63 | | - await ctx.deps.conn.execute( |
64 | | - 'insert into memory(user_id, value) values(?, ?) on conflict do nothing', |
65 | | - ctx.deps.user_id, |
66 | | - value, |
67 | | - commit=True, |
68 | | - ) |
69 | | - return 'Value added to memory.' |
# Create the app with the lifespan above so every request can reach the database.
app = fastapi.FastAPI(lifespan=lifespan)
logfire.instrument_fastapi(app)
70 | 55 |
|
71 | 56 |
|
72 | | -@chat_agent.tool |
73 | | -async def retrieve_memories(ctx: RunContext[Deps], memory_contains: str) -> str: |
74 | | - """Get all memories about the user.""" |
75 | | - rows = await ctx.deps.conn.fetchall( |
76 | | - 'select value from memory where user_id = ? and value like ?', |
77 | | - ctx.deps.user_id, |
78 | | - f'%{memory_contains}%', |
79 | | - ) |
80 | | - return '\n'.join([row[0] for row in rows]) |
@app.get('/')
async def index() -> FileResponse:
    """Serve the single-page chat UI."""
    html_file = THIS_DIR / 'chat_app.html'
    return FileResponse(html_file, media_type='text/html')
81 | 60 |
|
82 | 61 |
|
83 | | -app = fastapi.FastAPI(lifespan=lifespan) |
84 | | -logfire.instrument_fastapi(app) |
@app.get('/chat_app.ts')
async def main_ts() -> FileResponse:
    """Get the raw typescript code, it's compiled in the browser, forgive me."""
    ts_source = THIS_DIR / 'chat_app.ts'
    return FileResponse(ts_source, media_type='text/plain')
85 | 66 |
|
86 | 67 |
|
async def get_db(request: Request) -> Database:
    """Dependency: return the Database stored in app state by ``lifespan``."""
    return request.state.db
89 | 70 |
|
90 | 71 |
|
@app.get('/chat/')
async def get_chat(database: Database = Depends(get_db)) -> Response:
    """Return the stored chat history as newline-delimited JSON messages."""
    history = await database.get_messages()
    payload = b'\n'.join(
        json.dumps(to_chat_message(message)).encode('utf-8') for message in history
    )
    return Response(payload, media_type='text/plain')
94 | 79 |
|
95 | 80 |
|
96 | | -@app.post('/api/chat') |
97 | | -async def get_chat(request: Request, database: Database = Depends(get_db)) -> Response: |
98 | | - return await VercelAIAdapter[Deps].dispatch_request( |
99 | | - request, agent=chat_agent, deps=Deps(database, 123) |
100 | | - ) |
class ChatMessage(TypedDict):
    """Format of messages sent to the browser."""

    # 'user' for prompts typed by the person, 'model' for LLM responses.
    role: Literal['user', 'model']
    # ISO-8601 timestamp of when the message was created.
    timestamp: str
    # Plain-text (markdown) body of the message.
    content: str
| 87 | + |
| 88 | + |
def to_chat_message(m: ModelMessage) -> ChatMessage:
    """Convert a pydantic-ai message into the browser wire format.

    Only the first part of the message is inspected; anything that is not a
    user prompt or a text response is rejected.
    """
    part = m.parts[0]
    if isinstance(m, ModelRequest) and isinstance(part, UserPromptPart):
        assert isinstance(part.content, str)
        return {
            'role': 'user',
            'timestamp': part.timestamp.isoformat(),
            'content': part.content,
        }
    if isinstance(m, ModelResponse) and isinstance(part, TextPart):
        return {
            'role': 'model',
            'timestamp': m.timestamp.isoformat(),
            'content': part.content,
        }
    raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
| 107 | + |
| 108 | + |
@app.post('/chat/')
async def post_chat(
    prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
) -> StreamingResponse:
    """Accept a user prompt (form field), stream the agent's reply, then persist both."""

    async def stream_messages():
        """Streams new line delimited JSON `Message`s to the client."""
        # stream the user prompt so that can be displayed straight away
        yield (
            json.dumps(
                {
                    'role': 'user',
                    'timestamp': datetime.now(tz=timezone.utc).isoformat(),
                    'content': prompt,
                }
            ).encode('utf-8')
            + b'\n'
        )
        # get the chat history so far to pass as context to the agent
        messages = await database.get_messages()
        # run the agent with the user prompt and the chat history
        async with agent.run_stream(prompt, message_history=messages) as result:
            async for text in result.stream_output(debounce_by=0.01):
                # text here is a `str` and the frontend wants
                # JSON encoded ModelResponse, so we create one
                m = ModelResponse(parts=[TextPart(text)], timestamp=result.timestamp())
                yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'

        # add new messages (e.g. the user prompt and the agent response in this case) to the database
        await database.add_messages(result.new_messages_json())

    return StreamingResponse(stream_messages(), media_type='text/plain')
| 140 | + |
| 141 | + |
# Typing helpers for Database._asyncify: P captures the wrapped callable's
# parameters, R its return type.
P = ParamSpec('P')
R = TypeVar('R')
| 144 | + |
| 145 | + |
@dataclass
class Database:
    """Rudimentary database to store chat messages in SQLite.

    The SQLite standard library package is synchronous, so we
    use a thread pool executor to run queries asynchronously.
    """

    con: sqlite3.Connection
    _loop: asyncio.AbstractEventLoop
    _executor: ThreadPoolExecutor

    @classmethod
    @asynccontextmanager
    async def connect(
        cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
    ) -> AsyncIterator[Database]:
        """Open (creating if necessary) the SQLite file and yield a Database.

        All sqlite3 calls are funnelled through a single-threaded executor,
        since a sqlite3 connection may only be used from the thread that
        created it.
        """
        with logfire.span('connect to DB'):
            # get_running_loop: we are inside a coroutine; get_event_loop is
            # deprecated for this use and may warn/fail on newer Pythons.
            loop = asyncio.get_running_loop()
            executor = ThreadPoolExecutor(max_workers=1)
            con = await loop.run_in_executor(executor, cls._connect, file)
        slf = cls(con, loop, executor)
        try:
            yield slf
        finally:
            await slf._asyncify(con.close)
            # release the worker thread so the process can exit cleanly
            executor.shutdown(wait=True)

    @staticmethod
    def _connect(file: Path) -> sqlite3.Connection:
        """Open the connection and ensure the schema exists (runs on the executor thread)."""
        con = sqlite3.connect(str(file))
        con = logfire.instrument_sqlite3(con)
        cur = con.cursor()
        # INTEGER PRIMARY KEY (not INT) makes `id` an alias for SQLite's rowid,
        # so it auto-increments on insert; with INT, ids would be stored as NULL
        # and the `order by id` in get_messages would not reflect insertion order.
        cur.execute(
            'CREATE TABLE IF NOT EXISTS messages (id INTEGER PRIMARY KEY, message_list TEXT);'
        )
        con.commit()
        return con

    async def add_messages(self, messages: bytes):
        """Append one JSON-encoded batch of messages to the history."""
        # `commit=True` makes _execute commit on the executor thread; the
        # previous second `con.commit` round-trip was redundant.
        await self._asyncify(
            self._execute,
            'INSERT INTO messages (message_list) VALUES (?);',
            messages,
            commit=True,
        )

    async def get_messages(self) -> list[ModelMessage]:
        """Load the full chat history, oldest batch first."""
        c = await self._asyncify(
            self._execute, 'SELECT message_list FROM messages order by id'
        )
        rows = await self._asyncify(c.fetchall)
        messages: list[ModelMessage] = []
        for row in rows:
            messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
        return messages

    def _execute(
        self, sql: LiteralString, *args: Any, commit: bool = False
    ) -> sqlite3.Cursor:
        """Run a single statement (on the executor thread); optionally commit."""
        cur = self.con.cursor()
        cur.execute(sql, args)
        if commit:
            self.con.commit()
        return cur

    async def _asyncify(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """Run a synchronous callable on the single-threaded executor and await its result."""
        return await self._loop.run_in_executor(  # type: ignore
            self._executor,
            partial(func, **kwargs),
            *args,  # type: ignore
        )
101 | 220 |
|
102 | 221 |
|
103 | 222 | if __name__ == '__main__': |
|
0 commit comments