Skip to content

Commit 362e89f

Browse files
authored
Merge pull request #14 from Azure-Samples/howie/fully-customizable-event-handler
Howie/fully customizable event handler
2 parents 0ea9139 + 9cc2bc6 commit 362e89f

File tree

6 files changed

+92
-71
lines changed

6 files changed

+92
-71
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ To achieve this, when users submit a message to the web server, the web server w
4747

4848
## Local Development
4949

50-
1. Run `pip install -r requirements.txt`.
50+
1. Run `pip install -r ./src/requirements.txt`.
5151

5252
2. Make sure that the `.env` file exists.
5353

Binary file not shown.

src/quartapp/chat.py

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
33

4-
from typing import Any
4+
from typing import AsyncGenerator, Dict, Optional, Tuple
55
from quart import Blueprint, jsonify, request, Response, render_template, current_app
66

77
import asyncio
@@ -12,16 +12,54 @@
1212
from azure.identity import DefaultAzureCredential
1313

1414
from azure.ai.projects.models import (
15-
MessageDeltaTextContent,
1615
MessageDeltaChunk,
1716
ThreadMessage,
1817
FileSearchTool,
1918
AsyncToolSet,
2019
FilePurpose,
21-
AgentStreamEvent
20+
ThreadMessage,
21+
StreamEventData,
22+
AsyncAgentEventHandler,
23+
Agent,
24+
VectorStore
2225
)
2326

24-
bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static")
27+
class ChatBlueprint(Blueprint):
28+
ai_client: AIProjectClient
29+
agent: Agent
30+
files: Dict[str, str]
31+
vector_store: VectorStore
32+
33+
bp = ChatBlueprint("chat", __name__, template_folder="templates", static_folder="static")
34+
35+
class MyEventHandler(AsyncAgentEventHandler[str]):
36+
37+
async def on_message_delta(
38+
self, delta: "MessageDeltaChunk"
39+
) -> Optional[str]:
40+
stream_data = json.dumps({'content': delta.text, 'type': "message"})
41+
return f"data: {stream_data}\n\n"
42+
43+
async def on_thread_message(
44+
self, message: "ThreadMessage"
45+
) -> Optional[str]:
46+
if message.status == "completed":
47+
annotations = [annotation.as_dict() for annotation in message.file_citation_annotations]
48+
stream_data = json.dumps({'content': message.text_messages[0].text.value, 'annotations': annotations, 'type': "completed_message"})
49+
return f"data: {stream_data}\n\n"
50+
return None
51+
52+
async def on_error(self, data: str) -> Optional[str]:
53+
print(f"An error occurred. Data: {data}")
54+
stream_data = json.dumps({'type': "stream_end"})
55+
return f"data: {stream_data}\n\n"
56+
57+
async def on_done(
58+
self,
59+
) -> Optional[str]:
60+
stream_data = json.dumps({'type': "stream_end"})
61+
return f"data: {stream_data}\n\n"
62+
2563

2664

2765
@bp.before_app_serving
@@ -33,15 +71,15 @@ async def start_server():
3371
)
3472

3573
# TODO: add more files are not supported for citation at the moment
36-
files = ["product_info_1.md"]
37-
file_ids = []
38-
for file in files:
39-
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file))
74+
file_names = ["product_info_1.md", "product_info_2.md"]
75+
files: Dict[str, str] = {}
76+
for file_name in file_names:
77+
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
4078
print(f"Uploading file {file_path}")
4179
file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
42-
file_ids.append(file.id)
80+
files.update({file.id: file_path})
4381

44-
vector_store = await ai_client.agents.create_vector_store(file_ids=file_ids, name="sample_store")
82+
vector_store = await ai_client.agents.create_vector_store_and_poll(file_ids=list(files.keys()), name="sample_store")
4583

4684
file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
4785

@@ -59,12 +97,12 @@ async def start_server():
5997
bp.ai_client = ai_client
6098
bp.agent = agent
6199
bp.vector_store = vector_store
62-
bp.file_ids = file_ids
100+
bp.files = files
63101

64102

65103
@bp.after_app_serving
66104
async def stop_server():
67-
for file_id in bp.file_ids:
105+
for file_id in bp.files.keys():
68106
await bp.ai_client.agents.delete_file(file_id)
69107
print(f"Deleted file {file_id}")
70108

@@ -78,47 +116,32 @@ async def stop_server():
78116
await bp.ai_client.close()
79117
print("Closed AIProjectClient")
80118

119+
120+
121+
81122
@bp.get("/")
82123
async def index():
83124
return await render_template("index.html")
84125

85-
async def create_stream(thread_id: str, agent_id: str):
126+
127+
128+
async def get_result(thread_id: str, agent_id: str) -> AsyncGenerator[str, None]:
86129
async with await bp.ai_client.agents.create_stream(
87-
thread_id=thread_id, assistant_id=agent_id
130+
thread_id=thread_id, assistant_id=agent_id,
131+
event_handler=MyEventHandler()
88132
) as stream:
89-
accumulated_text = ""
90-
91-
async for event_type, event_data in stream:
92-
93-
stream_data = None
94-
if isinstance(event_data, MessageDeltaChunk):
95-
for content_part in event_data.delta.content:
96-
if isinstance(content_part, MessageDeltaTextContent):
97-
text_value = content_part.text.value if content_part.text else "No text"
98-
accumulated_text += text_value
99-
print(f"Text delta received: {text_value}")
100-
stream_data = json.dumps({'content': text_value, 'type': "message"})
101-
102-
elif isinstance(event_data, ThreadMessage):
103-
print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")
104-
if (event_data.status == "completed"):
105-
stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"})
106-
107-
elif event_type == AgentStreamEvent.DONE:
108-
print("Stream completed.")
109-
stream_data = json.dumps({'type': "stream_end"})
110-
111-
if stream_data:
112-
yield f"data: {stream_data}\n\n"
133+
# Iterate over the steam to trigger event functions
134+
async for _, _, event_func_return_val in stream:
135+
if event_func_return_val:
136+
yield event_func_return_val
113137

114-
115138
@bp.route('/chat', methods=['POST'])
116139
async def chat():
117140
thread_id = request.cookies.get('thread_id')
118141
agent_id = request.cookies.get('agent_id')
119142
thread = None
120143

121-
if thread_id or agent_id != bp.agent.id:
144+
if thread_id and agent_id == bp.agent.id:
122145
# Check if the thread is still active
123146
try:
124147
thread = await bp.ai_client.agents.get_thread(thread_id)
@@ -147,24 +170,21 @@ async def chat():
147170
'Content-Type': 'text/event-stream'
148171
}
149172

150-
response = Response(create_stream(thread_id, agent_id), headers=headers)
173+
response = Response(get_result(thread_id, agent_id), headers=headers)
151174
response.set_cookie('thread_id', thread_id)
152175
response.set_cookie('agent_id', agent_id)
153176
return response
154177

155178
@bp.route('/fetch-document', methods=['GET'])
156179
async def fetch_document():
157-
filename = "product_info_1.md"
158-
159-
# Get the file path from the mapping
160-
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', filename))
161-
162-
if not os.path.exists(file_path):
163-
return jsonify({"error": f"File not found: {filename}"}), 404
180+
file_id = request.args.get('file_id')
181+
current_app.logger.info(f"Fetching document: {file_id}")
182+
if not file_id:
183+
return jsonify({"error": "file_id is required"}), 400
164184

165185
try:
166186
# Read the file content asynchronously using asyncio.to_thread
167-
data = await asyncio.to_thread(read_file, file_path)
187+
data = await asyncio.to_thread(read_file, bp.files[file_id])
168188
return Response(data, content_type='text/plain')
169189

170190
except Exception as e:

src/quartapp/static/ChatClient.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class ChatClient {
3737
let accumulatedContent = '';
3838
let isStreaming = true;
3939
let buffer = '';
40+
let annotations = [];
4041

4142
const reader = stream.getReader();
4243
const decoder = new TextDecoder();
@@ -73,12 +74,13 @@ class ChatClient {
7374
if (data.type === "completed_message") {
7475
this.ui.clearAssistantMessage(messageDiv);
7576
accumulatedContent = data.content;
77+
annotations = data.annotations;
7678
isStreaming = false;
7779
} else {
7880
accumulatedContent += data.content;
7981
}
8082

81-
this.ui.appendAssistantMessage(messageDiv, accumulatedContent, isStreaming);
83+
this.ui.appendAssistantMessage(messageDiv, accumulatedContent, isStreaming, annotations);
8284
}
8385
}
8486

src/quartapp/static/ChatUI.js

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,31 @@ class ChatUI {
1313
this.attachCloseButtonListener();
1414
}
1515

16-
preprocessContent(content) {
17-
// Regular expression to find citations like 【n:m†filename.md】
18-
const citationRegex = /\u3010(\d+):(\d+)\u2020([^\s]+)\u3011/g;
19-
return content.replace(citationRegex, (match, _, __, filename) => {
20-
return `<a href="#" class="file-citation" data-file-name="${filename}">${match}</a>`;
21-
});
16+
preprocessContent(content, annotations) {
17+
if (annotations) {
18+
annotations.slice().reverse().forEach(annotation => {
19+
// the start and end index are the label of annotation. Replace them with a link
20+
content = content.slice(0, annotation.start_index) +
21+
`<a href="#" class="file-citation" data-file-id="${annotation.file_citation.file_id}">${annotation.text}</a>` +
22+
content.slice(annotation.end_index);
23+
});
24+
}
25+
return content;
2226
}
2327

2428
addCitationClickListener() {
2529
document.addEventListener('click', (event) => {
2630
if (event.target.classList.contains('file-citation')) {
2731
event.preventDefault();
28-
const filename = event.target.getAttribute('data-file-name');
29-
this.loadDocument(filename);
32+
const file_id = event.target.getAttribute('data-file-id');
33+
this.loadDocument(file_id);
3034
}
3135
});
3236
}
3337

34-
async loadDocument(filename) {
38+
async loadDocument(file_id) {
3539
try {
36-
const response = await fetch(`/fetch-document?filename=${filename}`);
40+
const response = await fetch(`/fetch-document?file_id=${file_id}`);
3741
if (!response.ok) {
3842
throw new Error('Network response was not ok');
3943
}
@@ -53,7 +57,6 @@ class ChatUI {
5357
}
5458

5559
showDocument(content) {
56-
console.log("showDocument:", content);
5760
const docViewerSection = document.getElementById("document-viewer-section");
5861
const chatColumn = document.getElementById("chat-container");
5962

@@ -109,8 +112,7 @@ class ChatUI {
109112
this.scrollToBottom();
110113
}
111114

112-
appendAssistantMessage(messageDiv, accumulatedContent, isStreaming) {
113-
//console.log("Accumulated Content before conversion:", accumulatedContent);
115+
appendAssistantMessage(messageDiv, accumulatedContent, isStreaming, annotations) {
114116
const md = window.markdownit({
115117
html: true,
116118
linkify: true,
@@ -120,7 +122,7 @@ class ChatUI {
120122

121123
try {
122124
// Preprocess content to convert citations to links
123-
const preprocessedContent = this.preprocessContent(accumulatedContent);
125+
const preprocessedContent = this.preprocessContent(accumulatedContent, annotations);
124126
// Convert the accumulated content to HTML using markdown-it
125127
let htmlContent = md.render(preprocessedContent);
126128
const messageTextDiv = messageDiv.querySelector(".message-text");
@@ -130,13 +132,10 @@ class ChatUI {
130132

131133
// Set the innerHTML of the message text div to the HTML content
132134
messageTextDiv.innerHTML = htmlContent;
133-
console.log("HTML set to messageTextDiv:", messageTextDiv.innerHTML);
134135

135136
// Use requestAnimationFrame to ensure the DOM has updated before scrolling
136137
// Only scroll if not streaming
137138
if (!isStreaming) {
138-
console.log("Accumulated content:", accumulatedContent);
139-
console.log("HTML set to messageTextDiv:", messageTextDiv.innerHTML);
140139
requestAnimationFrame(() => {
141140
this.scrollToBottom();
142141
});

src/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ anyio==4.3.0
1919
# watchfiles
2020
attrs==23.2.0
2121
# via aiohttp
22-
azure-core==1.30.1
22+
azure-core==1.31.0
2323
# via azure-identity
2424
azure-identity==1.15.0
2525
# via quartapp (pyproject.toml)
@@ -146,7 +146,7 @@ sniffio==1.3.1
146146
# openai
147147
tqdm==4.66.2
148148
# via openai
149-
typing-extensions==4.11.0
149+
typing-extensions==4.12.2
150150
# via
151151
# azure-core
152152
# openai
@@ -171,5 +171,5 @@ wsproto==1.2.0
171171
# via hypercorn
172172
yarl==1.9.4
173173
# via aiohttp
174-
./packages/azure_ai_projects-1.0.0b1-py3-none-any.whl
174+
azure-ai-projects==1.0.0b5
175175
# via quartapp (pyproject.toml)

0 commit comments

Comments
 (0)