
Commit fc1967d

use stream event handler in b4
1 parent e8aab58 commit fc1967d

5 files changed: +84 −69 lines
Binary file not shown.

src/quartapp/chat.py

Lines changed: 62 additions & 48 deletions
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

-from typing import Any, AsyncGenerator, Optional, Tuple
+from typing import AsyncGenerator, Dict, Optional, Tuple
 from quart import Blueprint, jsonify, request, Response, render_template, current_app

 import asyncio
@@ -12,19 +12,54 @@
 from azure.identity import DefaultAzureCredential

 from azure.ai.projects.models import (
-    MessageDeltaTextContent,
     MessageDeltaChunk,
     ThreadMessage,
     FileSearchTool,
     AsyncToolSet,
     FilePurpose,
     ThreadMessage,
-    ThreadError,
     StreamEventData,
-    AgentStreamEvent
+    AsyncAgentEventHandler,
+    Agent,
+    VectorStore
 )

-bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static")
+class ChatBlueprint(Blueprint):
+    ai_client: AIProjectClient
+    agent: Agent
+    files: Dict[str, str]
+    vector_store: VectorStore
+
+bp = ChatBlueprint("chat", __name__, template_folder="templates", static_folder="static")
+
+class MyEventHandler(AsyncAgentEventHandler[str]):
+
+    async def on_message_delta(
+        self, delta: "MessageDeltaChunk"
+    ) -> Optional[str]:
+        stream_data = json.dumps({'content': delta.text, 'type': "message"})
+        return f"data: {stream_data}\n\n"
+
+    async def on_thread_message(
+        self, message: "ThreadMessage"
+    ) -> Optional[str]:
+        if message.status == "completed":
+            annotations = [annotation.as_dict() for annotation in message.file_citation_annotations]
+            stream_data = json.dumps({'content': message.text_messages[0].text.value, 'annotations': annotations, 'type': "completed_message"})
+            return f"data: {stream_data}\n\n"
+        return None
+
+    async def on_error(self, data: str) -> Optional[str]:
+        print(f"An error occurred. Data: {data}")
+        stream_data = json.dumps({'type': "stream_end"})
+        return f"data: {stream_data}\n\n"
+
+    async def on_done(
+        self,
+    ) -> Optional[str]:
+        stream_data = json.dumps({'type': "stream_end"})
+        return f"data: {stream_data}\n\n"
+


 @bp.before_app_serving
@@ -36,15 +71,15 @@ async def start_server():
     )

     # TODO: add more files are not supported for citation at the moment
-    files = ["product_info_1.md"]
-    file_ids = []
-    for file in files:
-        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file))
+    file_names = ["product_info_1.md", "product_info_2.md"]
+    files: Dict[str, str] = {}
+    for file_name in file_names:
+        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
         print(f"Uploading file {file_path}")
         file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
-        file_ids.append(file.id)
+        files.update({file.id: file_path})

-    vector_store = await ai_client.agents.create_vector_store(file_ids=file_ids, name="sample_store")
+    vector_store = await ai_client.agents.create_vector_store_and_poll(file_ids=list(files.keys()), name="sample_store")

     file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])

@@ -62,12 +97,12 @@ async def start_server():
     bp.ai_client = ai_client
     bp.agent = agent
     bp.vector_store = vector_store
-    bp.file_ids = file_ids
+    bp.files = files


 @bp.after_app_serving
 async def stop_server():
-    for file_id in bp.file_ids:
+    for file_id in bp.files.keys():
         await bp.ai_client.agents.delete_file(file_id)
         print(f"Deleted file {file_id}")

@@ -81,49 +116,31 @@ async def stop_server():
     await bp.ai_client.close()
     print("Closed AIProjectClient")

-async def yield_callback(event_type: str, event_obj: StreamEventData, **kwargs) -> Optional[str]:
-    accumulated_text = kwargs['accumulated_text']
-    if (isinstance(event_obj, MessageDeltaTextContent)):
-        text_value = event_obj.text.value if event_obj.text else "No text"
-        stream_data = json.dumps({'content': text_value, 'type': "message"})
-        accumulated_text[0] += text_value
-        return f"data: {stream_data}\n\n"
-    elif isinstance(event_obj, ThreadMessage):
-        if (event_obj.status == "completed"):
-            stream_data = json.dumps({'content': accumulated_text[0], 'type': "completed_message"})
-            return f"data: {stream_data}\n\n"
-    elif isinstance(event_obj, ThreadError):
-        print(f"An error occurred. Data: {event_obj.error}")
-        stream_data = json.dumps({'type': "stream_end"})
-        return f"data: {stream_data}\n\n"
-    elif event_type == AgentStreamEvent.DONE:
-        stream_data = json.dumps({'type': "stream_end"})
-        return f"data: {stream_data}\n\n"
-
-    return None
+
+
+
 @bp.get("/")
 async def index():
     return await render_template("index.html")



-async def get_result(thread_id: str, agent_id: str):
-
-    accumulated_text = [""]
-
+async def get_result(thread_id: str, agent_id: str) -> AsyncGenerator[str, None]:
     async with await bp.ai_client.agents.create_stream(
         thread_id=thread_id, assistant_id=agent_id,
+        event_handler=MyEventHandler()
     ) as stream:
-        async for to_be_yield in stream.yield_until_done(yield_callback, accumulated_text=accumulated_text):
-            yield to_be_yield
+        async for _, _, to_be_yield in stream:
+            if to_be_yield:
+                yield to_be_yield

 @bp.route('/chat', methods=['POST'])
 async def chat():
     thread_id = request.cookies.get('thread_id')
     agent_id = request.cookies.get('agent_id')
     thread = None

-    if thread_id or agent_id != bp.agent.id:
+    if thread_id and agent_id == bp.agent.id:
         # Check if the thread is still active
         try:
             thread = await bp.ai_client.agents.get_thread(thread_id)
@@ -159,17 +176,14 @@ async def chat():

 @bp.route('/fetch-document', methods=['GET'])
 async def fetch_document():
-    filename = "product_info_1.md"
-
-    # Get the file path from the mapping
-    file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', filename))
-
-    if not os.path.exists(file_path):
-        return jsonify({"error": f"File not found: {filename}"}), 404
+    file_id = request.args.get('file_id')
+    current_app.logger.info(f"Fetching document: {file_id}")
+    if not file_id:
+        return jsonify({"error": "file_id is required"}), 400

     try:
         # Read the file content asynchronously using asyncio.to_thread
-        data = await asyncio.to_thread(read_file, file_path)
+        data = await asyncio.to_thread(read_file, bp.files[file_id])
         return Response(data, content_type='text/plain')

     except Exception as e:
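For reference, the new MyEventHandler defines a simple Server-Sent-Events contract with three payload types: "message" for incremental deltas, "completed_message" carrying the final text plus citation annotations, and "stream_end". The sketch below is illustrative only and not part of this commit; it assumes nothing beyond the payload shapes visible in the diff above, and the helper name consume_sse is hypothetical.

import json
from typing import Any, AsyncGenerator, Dict


async def consume_sse(frames: AsyncGenerator[str, None]) -> str:
    """Accumulate streamed deltas and return the final message text."""
    accumulated = ""
    async for frame in frames:
        # Each frame looks like "data: {...}\n\n"; strip the SSE framing first.
        payload: Dict[str, Any] = json.loads(frame.removeprefix("data: ").strip())
        if payload["type"] == "message":
            # Incremental delta text produced by on_message_delta.
            accumulated += payload["content"]
        elif payload["type"] == "completed_message":
            # The completed message replaces the accumulated deltas; it also
            # carries payload["annotations"] with file-citation metadata.
            accumulated = payload["content"]
        elif payload["type"] == "stream_end":
            break
    return accumulated

This mirrors what ChatClient.js does in the browser: append deltas while streaming, then replace the buffer with the completed message and hand its annotations to the UI.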

src/quartapp/static/ChatClient.js

Lines changed: 3 additions & 1 deletion
@@ -37,6 +37,7 @@ class ChatClient {
         let accumulatedContent = '';
         let isStreaming = true;
         let buffer = '';
+        let annotations = [];

         const reader = stream.getReader();
         const decoder = new TextDecoder();
@@ -73,12 +74,13 @@
                 if (data.type === "completed_message") {
                     this.ui.clearAssistantMessage(messageDiv);
                     accumulatedContent = data.content;
+                    annotations = data.annotations;
                     isStreaming = false;
                 } else {
                     accumulatedContent += data.content;
                 }

-                this.ui.appendAssistantMessage(messageDiv, accumulatedContent, isStreaming);
+                this.ui.appendAssistantMessage(messageDiv, accumulatedContent, isStreaming, annotations);
             }
         }


src/quartapp/static/ChatUI.js

Lines changed: 16 additions & 17 deletions
@@ -13,27 +13,31 @@ class ChatUI {
         this.attachCloseButtonListener();
     }

-    preprocessContent(content) {
-        // Regular expression to find citations like 【n:m†filename.md】
-        const citationRegex = /\u3010(\d+):(\d+)\u2020([^\s]+)\u3011/g;
-        return content.replace(citationRegex, (match, _, __, filename) => {
-            return `<a href="#" class="file-citation" data-file-name="${filename}">${match}</a>`;
-        });
+    preprocessContent(content, annotations) {
+        if (annotations) {
+            annotations.slice().reverse().forEach(annotation => {
+                // the start and end index are the label of annotation. Replace them with a link
+                content = content.slice(0, annotation.start_index) +
+                    `<a href="#" class="file-citation" data-file-id="${annotation.file_citation.file_id}">${annotation.text}</a>` +
+                    content.slice(annotation.end_index);
+            });
+        }
+        return content;
     }

     addCitationClickListener() {
         document.addEventListener('click', (event) => {
             if (event.target.classList.contains('file-citation')) {
                 event.preventDefault();
-                const filename = event.target.getAttribute('data-file-name');
-                this.loadDocument(filename);
+                const file_id = event.target.getAttribute('data-file-id');
+                this.loadDocument(file_id);
             }
         });
     }

-    async loadDocument(filename) {
+    async loadDocument(file_id) {
         try {
-            const response = await fetch(`/fetch-document?filename=${filename}`);
+            const response = await fetch(`/fetch-document?file_id=${file_id}`);
             if (!response.ok) {
                 throw new Error('Network response was not ok');
             }
@@ -53,7 +57,6 @@ class ChatUI {
     }

     showDocument(content) {
-        console.log("showDocument:", content);
         const docViewerSection = document.getElementById("document-viewer-section");
         const chatColumn = document.getElementById("chat-container");

@@ -109,8 +112,7 @@
         this.scrollToBottom();
     }

-    appendAssistantMessage(messageDiv, accumulatedContent, isStreaming) {
-        //console.log("Accumulated Content before conversion:", accumulatedContent);
+    appendAssistantMessage(messageDiv, accumulatedContent, isStreaming, annotations) {
         const md = window.markdownit({
             html: true,
             linkify: true,
@@ -120,7 +122,7 @@

         try {
             // Preprocess content to convert citations to links
-            const preprocessedContent = this.preprocessContent(accumulatedContent);
+            const preprocessedContent = this.preprocessContent(accumulatedContent, annotations);
             // Convert the accumulated content to HTML using markdown-it
             let htmlContent = md.render(preprocessedContent);
             const messageTextDiv = messageDiv.querySelector(".message-text");
@@ -130,13 +132,10 @@

             // Set the innerHTML of the message text div to the HTML content
             messageTextDiv.innerHTML = htmlContent;
-            console.log("HTML set to messageTextDiv:", messageTextDiv.innerHTML);

             // Use requestAnimationFrame to ensure the DOM has updated before scrolling
             // Only scroll if not streaming
             if (!isStreaming) {
-                console.log("Accumulated content:", accumulatedContent);
-                console.log("HTML set to messageTextDiv:", messageTextDiv.innerHTML);
                 requestAnimationFrame(() => {
                     this.scrollToBottom();
                 });
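The new preprocessContent applies annotations via slice().reverse(), i.e. from the end of the message toward the beginning, so each splice cannot shift the start_index/end_index of annotations earlier in the text. A Python rendering of the same splicing idea, for illustration only: link_citations is a hypothetical name, the keys match the annotation dicts emitted by on_thread_message in chat.py, and sorting by start_index descending is assumed to be equivalent to reversing an already-ordered list.

from typing import Any, Dict, List


def link_citations(content: str, annotations: List[Dict[str, Any]]) -> str:
    # Work from the last citation toward the first so earlier indices stay
    # valid after each replacement lengthens the string.
    for annotation in sorted(annotations, key=lambda a: a["start_index"], reverse=True):
        link = (
            f'<a href="#" class="file-citation" '
            f'data-file-id="{annotation["file_citation"]["file_id"]}">{annotation["text"]}</a>'
        )
        content = content[:annotation["start_index"]] + link + content[annotation["end_index"]:]
    return content

The data-file-id attribute is what loadDocument sends to /fetch-document?file_id=..., which the server resolves through the bp.files mapping built in start_server.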

src/requirements.txt

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@ anyio==4.3.0
     #   watchfiles
 attrs==23.2.0
     # via aiohttp
-azure-core==1.30.1
+azure-core==1.31.0
     # via azure-identity
 azure-identity==1.15.0
     # via quartapp (pyproject.toml)
@@ -146,7 +146,7 @@ sniffio==1.3.1
     #   openai
 tqdm==4.66.2
     # via openai
-typing-extensions==4.11.0
+typing-extensions==4.12.2
     # via
     #   azure-core
     #   openai
@@ -171,5 +171,5 @@ wsproto==1.2.0
     # via hypercorn
 yarl==1.9.4
     # via aiohttp
-./packages/azure_ai_projects-1.0.0b1-py3-none-any.whl
+azure-ai-projects==1.0.0b4
     # via quartapp (pyproject.toml)
