1
1
# Copyright (c) Microsoft. All rights reserved.
2
2
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3
3
4
- from typing import Any , AsyncGenerator , Optional , Tuple
4
+ from typing import AsyncGenerator , Dict , Optional , Tuple
5
5
from quart import Blueprint , jsonify , request , Response , render_template , current_app
6
6
7
7
import asyncio
12
12
from azure .identity import DefaultAzureCredential
13
13
14
14
from azure .ai .projects .models import (
15
- MessageDeltaTextContent ,
16
15
MessageDeltaChunk ,
17
16
ThreadMessage ,
18
17
FileSearchTool ,
19
18
AsyncToolSet ,
20
19
FilePurpose ,
21
20
ThreadMessage ,
22
- ThreadError ,
23
21
StreamEventData ,
24
- AgentStreamEvent
22
+ AsyncAgentEventHandler ,
23
+ Agent ,
24
+ VectorStore
25
25
)
26
26
27
- bp = Blueprint ("chat" , __name__ , template_folder = "templates" , static_folder = "static" )
27
+ class ChatBlueprint (Blueprint ):
28
+ ai_client : AIProjectClient
29
+ agent : Agent
30
+ files : Dict [str , str ]
31
+ vector_store : VectorStore
32
+
33
+ bp = ChatBlueprint ("chat" , __name__ , template_folder = "templates" , static_folder = "static" )
34
+
35
+ class MyEventHandler (AsyncAgentEventHandler [str ]):
36
+
37
+ async def on_message_delta (
38
+ self , delta : "MessageDeltaChunk"
39
+ ) -> Optional [str ]:
40
+ stream_data = json .dumps ({'content' : delta .text , 'type' : "message" })
41
+ return f"data: { stream_data } \n \n "
42
+
43
+ async def on_thread_message (
44
+ self , message : "ThreadMessage"
45
+ ) -> Optional [str ]:
46
+ if message .status == "completed" :
47
+ annotations = [annotation .as_dict () for annotation in message .file_citation_annotations ]
48
+ stream_data = json .dumps ({'content' : message .text_messages [0 ].text .value , 'annotations' : annotations , 'type' : "completed_message" })
49
+ return f"data: { stream_data } \n \n "
50
+ return None
51
+
52
+ async def on_error (self , data : str ) -> Optional [str ]:
53
+ print (f"An error occurred. Data: { data } " )
54
+ stream_data = json .dumps ({'type' : "stream_end" })
55
+ return f"data: { stream_data } \n \n "
56
+
57
+ async def on_done (
58
+ self ,
59
+ ) -> Optional [str ]:
60
+ stream_data = json .dumps ({'type' : "stream_end" })
61
+ return f"data: { stream_data } \n \n "
62
+
28
63
29
64
30
65
@bp .before_app_serving
@@ -36,15 +71,15 @@ async def start_server():
36
71
)
37
72
38
73
# TODO: add more files are not supported for citation at the moment
39
- files = ["product_info_1.md" ]
40
- file_ids = []
41
- for file in files :
42
- file_path = os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , 'files' , file ))
74
+ file_names = ["product_info_1.md" , "product_info_2 .md" ]
75
+ files : Dict [ str , str ] = {}
76
+ for file_name in file_names :
77
+ file_path = os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , 'files' , file_name ))
43
78
print (f"Uploading file { file_path } " )
44
79
file = await ai_client .agents .upload_file_and_poll (file_path = file_path , purpose = FilePurpose .AGENTS )
45
- file_ids . append ( file .id )
80
+ files . update ({ file .id : file_path } )
46
81
47
- vector_store = await ai_client .agents .create_vector_store (file_ids = file_ids , name = "sample_store" )
82
+ vector_store = await ai_client .agents .create_vector_store_and_poll (file_ids = list ( files . keys ()) , name = "sample_store" )
48
83
49
84
file_search_tool = FileSearchTool (vector_store_ids = [vector_store .id ])
50
85
@@ -62,12 +97,12 @@ async def start_server():
62
97
bp .ai_client = ai_client
63
98
bp .agent = agent
64
99
bp .vector_store = vector_store
65
- bp .file_ids = file_ids
100
+ bp .files = files
66
101
67
102
68
103
@bp .after_app_serving
69
104
async def stop_server ():
70
- for file_id in bp .file_ids :
105
+ for file_id in bp .files . keys () :
71
106
await bp .ai_client .agents .delete_file (file_id )
72
107
print (f"Deleted file { file_id } " )
73
108
@@ -81,49 +116,31 @@ async def stop_server():
81
116
await bp .ai_client .close ()
82
117
print ("Closed AIProjectClient" )
83
118
84
- async def yield_callback (event_type : str , event_obj : StreamEventData , ** kwargs ) -> Optional [str ]:
85
- accumulated_text = kwargs ['accumulated_text' ]
86
- if (isinstance (event_obj , MessageDeltaTextContent )):
87
- text_value = event_obj .text .value if event_obj .text else "No text"
88
- stream_data = json .dumps ({'content' : text_value , 'type' : "message" })
89
- accumulated_text [0 ] += text_value
90
- return f"data: { stream_data } \n \n "
91
- elif isinstance (event_obj , ThreadMessage ):
92
- if (event_obj .status == "completed" ):
93
- stream_data = json .dumps ({'content' : accumulated_text [0 ], 'type' : "completed_message" })
94
- return f"data: { stream_data } \n \n "
95
- elif isinstance (event_obj , ThreadError ):
96
- print (f"An error occurred. Data: { event_obj .error } " )
97
- stream_data = json .dumps ({'type' : "stream_end" })
98
- return f"data: { stream_data } \n \n "
99
- elif event_type == AgentStreamEvent .DONE :
100
- stream_data = json .dumps ({'type' : "stream_end" })
101
- return f"data: { stream_data } \n \n "
102
-
103
- return None
119
+
120
+
121
+
104
122
@bp .get ("/" )
105
123
async def index ():
106
124
return await render_template ("index.html" )
107
125
108
126
109
127
110
- async def get_result (thread_id : str , agent_id : str ):
111
-
112
- accumulated_text = ["" ]
113
-
128
+ async def get_result (thread_id : str , agent_id : str ) -> AsyncGenerator [str , None ]:
114
129
async with await bp .ai_client .agents .create_stream (
115
130
thread_id = thread_id , assistant_id = agent_id ,
131
+ event_handler = MyEventHandler ()
116
132
) as stream :
117
- async for to_be_yield in stream .yield_until_done (yield_callback , accumulated_text = accumulated_text ):
118
- yield to_be_yield
133
+ async for _ , _ , to_be_yield in stream :
134
+ if to_be_yield :
135
+ yield to_be_yield
119
136
120
137
@bp .route ('/chat' , methods = ['POST' ])
121
138
async def chat ():
122
139
thread_id = request .cookies .get ('thread_id' )
123
140
agent_id = request .cookies .get ('agent_id' )
124
141
thread = None
125
142
126
- if thread_id or agent_id ! = bp .agent .id :
143
+ if thread_id and agent_id = = bp .agent .id :
127
144
# Check if the thread is still active
128
145
try :
129
146
thread = await bp .ai_client .agents .get_thread (thread_id )
@@ -159,17 +176,14 @@ async def chat():
159
176
160
177
@bp .route ('/fetch-document' , methods = ['GET' ])
161
178
async def fetch_document ():
162
- filename = "product_info_1.md"
163
-
164
- # Get the file path from the mapping
165
- file_path = os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , 'files' , filename ))
166
-
167
- if not os .path .exists (file_path ):
168
- return jsonify ({"error" : f"File not found: { filename } " }), 404
179
+ file_id = request .args .get ('file_id' )
180
+ current_app .logger .info (f"Fetching document: { file_id } " )
181
+ if not file_id :
182
+ return jsonify ({"error" : "file_id is required" }), 400
169
183
170
184
try :
171
185
# Read the file content asynchronously using asyncio.to_thread
172
- data = await asyncio .to_thread (read_file , file_path )
186
+ data = await asyncio .to_thread (read_file , bp . files [ file_id ] )
173
187
return Response (data , content_type = 'text/plain' )
174
188
175
189
except Exception as e :
0 commit comments