8
8
from __future__ import annotations as _annotations
9
9
10
10
import asyncio
11
+ import json
11
12
import sqlite3
12
13
from collections .abc import AsyncIterator
13
14
from concurrent .futures .thread import ThreadPoolExecutor
14
15
from contextlib import asynccontextmanager
15
16
from dataclasses import dataclass
17
+ from datetime import datetime , timezone
16
18
from functools import partial
17
19
from pathlib import Path
18
- from typing import Annotated , Any , Callable , TypeVar
20
+ from typing import Annotated , Any , Callable , Literal , TypeVar
19
21
20
22
import fastapi
21
23
import logfire
22
24
from fastapi import Depends , Request
23
25
from fastapi .responses import HTMLResponse , Response , StreamingResponse
24
26
from pydantic import Field , TypeAdapter
25
- from typing_extensions import LiteralString , ParamSpec
27
+ from typing_extensions import LiteralString , ParamSpec , TypedDict
26
28
27
29
from pydantic_ai import Agent
30
+ from pydantic_ai .exceptions import UnexpectedModelBehavior
28
31
from pydantic_ai .messages import (
29
32
Message ,
30
33
MessagesTypeAdapter ,
31
34
ModelResponse ,
35
+ TextPart ,
32
36
UserPrompt ,
33
37
)
34
38
@@ -68,19 +72,54 @@ async def get_db(request: Request) -> Database:
68
72
async def get_chat (database : Database = Depends (get_db )) -> Response :
69
73
msgs = await database .get_messages ()
70
74
return Response (
71
- b'\n ' .join (MessageTypeAdapter . dump_json ( m ) for m in msgs ),
75
+ b'\n ' .join (json . dumps ( to_chat_message ( m )). encode ( 'utf-8' ) for m in msgs ),
72
76
media_type = 'text/plain' ,
73
77
)
74
78
75
79
80
+ class ChatMessage (TypedDict ):
81
+ """Format of messages sent to the browser."""
82
+
83
+ role : Literal ['user' , 'model' ]
84
+ timestamp : str
85
+ content : str
86
+
87
+
88
+ def to_chat_message (m : Message ) -> ChatMessage :
89
+ if isinstance (m , UserPrompt ):
90
+ return {
91
+ 'role' : 'user' ,
92
+ 'timestamp' : m .timestamp .isoformat (),
93
+ 'content' : m .content ,
94
+ }
95
+ elif isinstance (m , ModelResponse ):
96
+ first_part = m .parts [0 ]
97
+ if isinstance (first_part , TextPart ):
98
+ return {
99
+ 'role' : 'model' ,
100
+ 'timestamp' : m .timestamp .isoformat (),
101
+ 'content' : first_part .content ,
102
+ }
103
+ raise UnexpectedModelBehavior (f'Unexpected message type for chat app: { m } ' )
104
+
105
+
76
106
@app .post ('/chat/' )
77
107
async def post_chat (
78
108
prompt : Annotated [str , fastapi .Form ()], database : Database = Depends (get_db )
79
109
) -> StreamingResponse :
80
110
async def stream_messages ():
81
111
"""Streams new line delimited JSON `Message`s to the client."""
82
112
# stream the user prompt so that can be displayed straight away
83
- yield MessageTypeAdapter .dump_json (UserPrompt (content = prompt )) + b'\n '
113
+ yield (
114
+ json .dumps (
115
+ {
116
+ 'role' : 'user' ,
117
+ 'timestamp' : datetime .now (tz = timezone .utc ).isoformat (),
118
+ 'content' : prompt ,
119
+ }
120
+ ).encode ('utf-8' )
121
+ + b'\n '
122
+ )
84
123
# get the chat history so far to pass as context to the agent
85
124
messages = await database .get_messages ()
86
125
# run the agent with the user prompt and the chat history
@@ -89,7 +128,7 @@ async def stream_messages():
89
128
# text here is a `str` and the frontend wants
90
129
# JSON encoded ModelResponse, so we create one
91
130
m = ModelResponse .from_text (content = text , timestamp = result .timestamp ())
92
- yield MessageTypeAdapter . dump_json ( m ) + b'\n '
131
+ yield json . dumps ( to_chat_message ( m )). encode ( 'utf-8' ) + b'\n '
93
132
94
133
# add new messages (e.g. the user prompt and the agent response in this case) to the database
95
134
await database .add_messages (result .new_messages_json ())
0 commit comments