Skip to content

Commit 6dc3e8d

Browse files
authored
add streaming to chat app (#48)
1 parent 48d2fd1 commit 6dc3e8d

File tree

6 files changed

+147
-69
lines changed

6 files changed

+147
-69
lines changed

docs/examples/chat-app.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ Demonstrates:
88

99
* reusing chat history
1010
* serializing messages
11+
* streaming responses
1112

1213
This demonstrates storing chat history between requests and using it to give the model context for new responses.
1314

14-
Most of the complex logic here is in `chat_app.html` which includes the page layout and JavaScript to handle the chat.
15+
Most of the complex logic here is split between `chat_app.py`, which streams the response to the browser,
16+
and `chat_app.ts`, which renders messages in the browser.
1517

1618
## Running the Example
1719

@@ -27,10 +29,20 @@ TODO screenshot.
2729

2830
## Example Code
2931

32+
Python code that runs the chat app:
33+
3034
```py title="chat_app.py"
3135
#! pydantic_ai_examples/chat_app.py
3236
```
3337

38+
Simple HTML page to render the app:
39+
3440
```html title="chat_app.html"
3541
#! pydantic_ai_examples/chat_app.html
3642
```
43+
44+
TypeScript to handle rendering the messages. To keep this simple (and at the risk of offending frontend developers), the TypeScript code is passed to the browser as plain text and transpiled in the browser.
45+
46+
```ts title="chat_app.ts"
47+
#! pydantic_ai_examples/chat_app.ts
48+
```

pydantic_ai/result.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from abc import ABC, abstractmethod
44
from collections.abc import AsyncIterator
55
from dataclasses import dataclass
6+
from datetime import datetime
67
from typing import Generic, TypeVar, cast
78

89
import logfire_api
@@ -273,6 +274,10 @@ def cost(self) -> Cost:
273274
"""
274275
return self.cost_so_far + self._stream_response.cost()
275276

277+
def timestamp(self) -> datetime:
278+
"""Get the timestamp of the response."""
279+
return self._stream_response.timestamp()
280+
276281
async def validate_structured_result(
277282
self, message: messages.ModelStructuredResponse, *, allow_partial: bool = False
278283
) -> ResultData:

pydantic_ai_examples/chat_app.html

Lines changed: 15 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -58,71 +58,24 @@ <h1>Chat App</h1>
5858
</main>
5959
</body>
6060
</html>
61+
<script src="https://cdnjs.cloudflare.com/ajax/libs/typescript/5.6.3/typescript.min.js" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
6162
<script type="module">
62-
import { marked } from 'https://cdn.jsdelivr.net/npm/marked/lib/marked.esm.js';
63-
64-
function addMessages(lines) {
65-
const messages = lines.filter(line => line.length > 1).map((line) => JSON.parse(line))
66-
const parent = document.getElementById('conversation');
67-
for (const message of messages) {
68-
let msgDiv = document.createElement('div');
69-
msgDiv.classList.add('border-top', 'pt-2', message.role);
70-
msgDiv.innerHTML = marked.parse(message.content);
71-
parent.appendChild(msgDiv);
72-
}
63+
// to let me write TypeScript without adding the burden of npm, we do a dirty, non-production-ready hack
64+
// and transpile the TypeScript code in the browser
65+
// this is (arguably) a neat demo trick, but not suitable for production!
66+
async function loadTs() {
67+
const response = await fetch('/chat_app.ts');
68+
const tsCode = await response.text();
69+
const jsCode = window.ts.transpile(tsCode, { target: "es2015" });
70+
let script = document.createElement('script');
71+
script.type = 'module';
72+
script.text = jsCode;
73+
document.body.appendChild(script);
7374
}
7475

75-
function onError(error) {
76-
console.error(error);
76+
loadTs().catch((e) => {
77+
console.error(e);
7778
document.getElementById('error').classList.remove('d-none');
7879
document.getElementById('spinner').classList.remove('active');
79-
}
80-
81-
async function fetchResponse(response) {
82-
let text = '';
83-
if (response.ok) {
84-
const reader = response.body.getReader();
85-
while (true) {
86-
const {done, value} = await reader.read();
87-
if (done) {
88-
break;
89-
}
90-
text += new TextDecoder().decode(value);
91-
const lines = text.split('\n');
92-
if (lines.length > 1) {
93-
addMessages(lines.slice(0, -1));
94-
text = lines[lines.length - 1];
95-
}
96-
}
97-
addMessages(text.split('\n'));
98-
let input = document.getElementById('prompt-input')
99-
input.disabled = false;
100-
input.focus();
101-
} else {
102-
const text = await response.text();
103-
console.error(`Unexpected response: ${response.status}`, {response, text});
104-
throw new Error(`Unexpected response: ${response.status}`);
105-
}
106-
}
107-
108-
async function onSubmit(e) {
109-
e.preventDefault();
110-
const spinner = document.getElementById('spinner');
111-
spinner.classList.add('active');
112-
const body = new FormData(e.target);
113-
114-
let input = document.getElementById('prompt-input')
115-
input.value = '';
116-
input.disabled = true;
117-
118-
const response = await fetch('/chat/', {method: 'POST', body});
119-
await fetchResponse(response);
120-
spinner.classList.remove('active');
121-
}
122-
123-
// call onSubmit when form is submitted (e.g. user clicks the send button or hits Enter)
124-
document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError));
125-
126-
// load messages on page load
127-
fetch('/chat/').then(fetchResponse).catch(onError);
80+
});
12881
</script>

pydantic_ai_examples/chat_app.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
from pydantic import Field, TypeAdapter
1717

1818
from pydantic_ai import Agent
19-
from pydantic_ai.messages import Message, MessagesTypeAdapter, UserPrompt
19+
from pydantic_ai.messages import (
20+
Message,
21+
MessagesTypeAdapter,
22+
ModelTextResponse,
23+
UserPrompt,
24+
)
2025

2126
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
2227
logfire.configure(send_to_logfire='if-token-present')
@@ -32,6 +37,12 @@ async def index() -> HTMLResponse:
3237
return HTMLResponse((THIS_DIR / 'chat_app.html').read_bytes())
3338

3439

40+
@app.get('/chat_app.ts')
41+
async def main_ts() -> Response:
42+
"""Get the raw TypeScript code; it's compiled in the browser, forgive me."""
43+
return Response((THIS_DIR / 'chat_app.ts').read_bytes(), media_type='text/plain')
44+
45+
3546
@app.get('/chat/')
3647
async def get_chat() -> Response:
3748
msgs = database.get_messages()
@@ -49,12 +60,16 @@ async def stream_messages():
4960
yield MessageTypeAdapter.dump_json(UserPrompt(content=prompt)) + b'\n'
5061
# get the chat history so far to pass as context to the agent
5162
messages = list(database.get_messages())
52-
response = await agent.run(prompt, message_history=messages)
63+
# run the agent with the user prompt and the chat history
64+
async with agent.run_stream(prompt, message_history=messages) as result:
65+
async for text in result.stream(debounce_by=0.01):
66+
# text here is a `str` and the frontend wants
67+
# JSON encoded ModelTextResponse, so we create one
68+
m = ModelTextResponse(content=text, timestamp=result.timestamp())
69+
yield MessageTypeAdapter.dump_json(m) + b'\n'
70+
5371
# add new messages (e.g. the user prompt and the agent response in this case) to the database
54-
database.add_messages(response.new_messages_json())
55-
# stream the last message which will be the agent response, we can't just yield `new_messages_json()`
56-
# since we already stream the user prompt
57-
yield MessageTypeAdapter.dump_json(response.all_messages()[-1]) + b'\n'
72+
database.add_messages(result.new_messages_json())
5873

5974
return StreamingResponse(stream_messages(), media_type='text/plain')
6075

pydantic_ai_examples/chat_app.ts

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// BIG FAT WARNING: to avoid the complexity of npm, this typescript is compiled in the browser
2+
// there's currently no static type checking
3+
4+
import { marked } from 'https://cdnjs.cloudflare.com/ajax/libs/marked/15.0.0/lib/marked.esm.js'
5+
const convElement = document.getElementById('conversation')
6+
7+
const promptInput = document.getElementById('prompt-input') as HTMLInputElement
8+
const spinner = document.getElementById('spinner')
9+
10+
// stream the response and render messages as each chunk is received
11+
// data is sent as newline-delimited JSON
12+
async function onFetchResponse(response: Response): Promise<void> {
13+
let text = ''
14+
let decoder = new TextDecoder()
15+
if (response.ok) {
16+
const reader = response.body.getReader()
17+
while (true) {
18+
const {done, value} = await reader.read()
19+
if (done) {
20+
break
21+
}
22+
text += decoder.decode(value)
23+
addMessages(text)
24+
spinner.classList.remove('active')
25+
}
26+
addMessages(text)
27+
promptInput.disabled = false
28+
promptInput.focus()
29+
} else {
30+
const text = await response.text()
31+
console.error(`Unexpected response: ${response.status}`, {response, text})
32+
throw new Error(`Unexpected response: ${response.status}`)
33+
}
34+
}
35+
36+
// The format of messages, this matches pydantic-ai both for brevity and understanding
37+
// in production, you might not want to keep this format all the way to the frontend
38+
interface Message {
39+
role: string
40+
content: string
41+
timestamp: string
42+
}
43+
44+
// take raw response text and render messages into the `#conversation` element
45+
// Message timestamp is assumed to be a unique identifier of a message, and is used to deduplicate
46+
// hence you can send data about the same message multiple times, and it will be updated
47+
// instead of creating a new message element
48+
function addMessages(responseText: string) {
49+
const lines = responseText.split('\n')
50+
const messages: Message[] = lines.filter(line => line.length > 1).map(j => JSON.parse(j))
51+
for (const message of messages) {
52+
// we use the timestamp as a crude element id
53+
const {timestamp, role, content} = message
54+
const id = `msg-${timestamp}`
55+
let msgDiv = document.getElementById(id)
56+
if (!msgDiv) {
57+
msgDiv = document.createElement('div')
58+
msgDiv.id = id
59+
msgDiv.title = `${role} at ${timestamp}`
60+
msgDiv.classList.add('border-top', 'pt-2', role)
61+
convElement.appendChild(msgDiv)
62+
}
63+
msgDiv.innerHTML = marked.parse(content)
64+
}
65+
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' })
66+
}
67+
68+
function onError(error: any) {
69+
console.error(error)
70+
document.getElementById('error').classList.remove('d-none')
71+
document.getElementById('spinner').classList.remove('active')
72+
}
73+
74+
async function onSubmit(e: SubmitEvent): Promise<void> {
75+
e.preventDefault()
76+
spinner.classList.add('active')
77+
const body = new FormData(e.target as HTMLFormElement)
78+
79+
promptInput.value = ''
80+
promptInput.disabled = true
81+
82+
const response = await fetch('/chat/', {method: 'POST', body})
83+
await onFetchResponse(response)
84+
}
85+
86+
// call onSubmit when the form is submitted (e.g. user clicks the send button or hits Enter)
87+
document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError))
88+
89+
// load messages on page load
90+
fetch('/chat/').then(onFetchResponse).catch(onError)

tests/test_streaming.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
2323
from pydantic_ai.models.test import TestModel
24+
from pydantic_ai.result import Cost
2425
from tests.conftest import IsNow
2526

2627
pytestmark = pytest.mark.anyio
@@ -51,6 +52,8 @@ async def ret_a(x: str) -> str:
5152
response = await result.get_data()
5253
assert response == snapshot('{"ret_a":"a-apple"}')
5354
assert result.is_complete
55+
assert result.cost() == snapshot(Cost())
56+
assert result.timestamp() == IsNow(tz=timezone.utc)
5457
assert result.all_messages() == snapshot(
5558
[
5659
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),

0 commit comments

Comments
 (0)