# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

- from typing import Any
+ from typing import AsyncGenerator, Dict, Optional, Tuple
from quart import Blueprint, jsonify, request, Response, render_template, current_app

import asyncio

from azure.identity import DefaultAzureCredential

from azure.ai.projects.models import (
-     MessageDeltaTextContent,
    MessageDeltaChunk,
    ThreadMessage,
    FileSearchTool,
    AsyncToolSet,
    FilePurpose,
-     AgentStreamEvent
+     ThreadMessage,
+     StreamEventData,
+     AsyncAgentEventHandler,
+     Agent,
+     VectorStore
)

- bp = Blueprint("chat", __name__, template_folder="templates", static_folder="static")
+ class ChatBlueprint(Blueprint):
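+     # Typed attributes populated in start_server() and cleaned up in stop_server()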
+     ai_client: AIProjectClient
+     agent: Agent
+     files: Dict[str, str]
+     vector_store: VectorStore
+
+ bp = ChatBlueprint("chat", __name__, template_folder="templates", static_folder="static")
+
+ class MyEventHandler(AsyncAgentEventHandler[str]):
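+     # Each callback returns a Server-Sent Events chunk (or None) that get_result() yields to the client.
+     # Payload shapes produced below:
+     #   {"type": "message", "content": "<delta text>"}
+     #   {"type": "completed_message", "content": "<full text>", "annotations": [...]}
+     #   {"type": "stream_end"}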
+
+     async def on_message_delta(
+         self, delta: "MessageDeltaChunk"
+     ) -> Optional[str]:
+         stream_data = json.dumps({'content': delta.text, 'type': "message"})
+         return f"data: {stream_data}\n\n"
+
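+     # Forward a message only once it is completed, together with its file citation annotations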
+     async def on_thread_message(
+         self, message: "ThreadMessage"
+     ) -> Optional[str]:
+         if message.status == "completed":
+             annotations = [annotation.as_dict() for annotation in message.file_citation_annotations]
+             stream_data = json.dumps({'content': message.text_messages[0].text.value, 'annotations': annotations, 'type': "completed_message"})
+             return f"data: {stream_data}\n\n"
+         return None
+
+     async def on_error(self, data: str) -> Optional[str]:
+         print(f"An error occurred. Data: {data}")
+         stream_data = json.dumps({'type': "stream_end"})
+         return f"data: {stream_data}\n\n"
+
+     async def on_done(
+         self,
+     ) -> Optional[str]:
+         stream_data = json.dumps({'type': "stream_end"})
+         return f"data: {stream_data}\n\n"
+


@bp.before_app_serving
@@ -33,15 +71,15 @@ async def start_server():
    )

    # TODO: add more files (additional files are not supported for citation at the moment)
-     files = ["product_info_1.md"]
-     file_ids = []
-     for file in files:
-         file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file))
+     file_names = ["product_info_1.md", "product_info_2.md"]
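+     # Map each uploaded file's ID to its local path so /fetch-document can serve the original content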
+     files: Dict[str, str] = {}
+     for file_name in file_names:
+         file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
        print(f"Uploading file {file_path}")
        file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
-         file_ids.append(file.id)
+         files.update({file.id: file_path})

-     vector_store = await ai_client.agents.create_vector_store(file_ids=file_ids, name="sample_store")
+     vector_store = await ai_client.agents.create_vector_store_and_poll(file_ids=list(files.keys()), name="sample_store")

    file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
@@ -59,12 +97,12 @@ async def start_server():
    bp.ai_client = ai_client
    bp.agent = agent
    bp.vector_store = vector_store
-     bp.file_ids = file_ids
+     bp.files = files


@bp.after_app_serving
async def stop_server():
-     for file_id in bp.file_ids:
+     for file_id in bp.files.keys():
        await bp.ai_client.agents.delete_file(file_id)
        print(f"Deleted file {file_id}")
@@ -78,47 +116,32 @@ async def stop_server():
    await bp.ai_client.close()
    print("Closed AIProjectClient")

+
+
+
@bp.get("/")
async def index():
    return await render_template("index.html")

- async def create_stream(thread_id: str, agent_id: str):
+
+
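+ # Async generator that runs the agent stream and yields the SSE chunks produced by MyEventHandler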
+ async def get_result(thread_id: str, agent_id: str) -> AsyncGenerator[str, None]:
    async with await bp.ai_client.agents.create_stream(
-         thread_id=thread_id, assistant_id=agent_id
+         thread_id=thread_id, assistant_id=agent_id,
+         event_handler=MyEventHandler()
    ) as stream:
-         accumulated_text = ""
-
-         async for event_type, event_data in stream:
-
-             stream_data = None
-             if isinstance(event_data, MessageDeltaChunk):
-                 for content_part in event_data.delta.content:
-                     if isinstance(content_part, MessageDeltaTextContent):
-                         text_value = content_part.text.value if content_part.text else "No text"
-                         accumulated_text += text_value
-                         print(f"Text delta received: {text_value}")
-                         stream_data = json.dumps({'content': text_value, 'type': "message"})
-
-             elif isinstance(event_data, ThreadMessage):
-                 print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}")
-                 if (event_data.status == "completed"):
-                     stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"})
-
-             elif event_type == AgentStreamEvent.DONE:
-                 print("Stream completed.")
-                 stream_data = json.dumps({'type': "stream_end"})
-
-             if stream_data:
-                 yield f"data: {stream_data}\n\n"
+         # Iterate over the stream to trigger the event handler functions
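+         # Each item is a 3-tuple; only the last element (the handler's return value) is needed here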
+         async for _, _, event_func_return_val in stream:
+             if event_func_return_val:
+                 yield event_func_return_val

-

@bp.route('/chat', methods=['POST'])
async def chat():
    thread_id = request.cookies.get('thread_id')
    agent_id = request.cookies.get('agent_id')
    thread = None

-     if thread_id or agent_id != bp.agent.id:
+     if thread_id and agent_id == bp.agent.id:
        # Check if the thread is still active
        try:
            thread = await bp.ai_client.agents.get_thread(thread_id)
@@ -147,24 +170,21 @@ async def chat():
        'Content-Type': 'text/event-stream'
    }

-     response = Response(create_stream(thread_id, agent_id), headers=headers)
+     response = Response(get_result(thread_id, agent_id), headers=headers)
    response.set_cookie('thread_id', thread_id)
    response.set_cookie('agent_id', agent_id)
    return response

@bp.route('/fetch-document', methods=['GET'])
async def fetch_document():
-     filename = "product_info_1.md"
-
-     # Get the file path from the mapping
-     file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', filename))
-
-     if not os.path.exists(file_path):
-         return jsonify({"error": f"File not found: {filename}"}), 404
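+     # Look up the requested file_id in the bp.files mapping populated at startup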
+     file_id = request.args.get('file_id')
+     current_app.logger.info(f"Fetching document: {file_id}")
+     if not file_id:
+         return jsonify({"error": "file_id is required"}), 400

    try:
        # Read the file content asynchronously using asyncio.to_thread
-         data = await asyncio.to_thread(read_file, file_path)
+         data = await asyncio.to_thread(read_file, bp.files[file_id])
        return Response(data, content_type='text/plain')

    except Exception as e: