|
1 | 1 | # Copyright (c) Microsoft. All rights reserved.
|
2 | 2 | # Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
3 | 3 |
|
4 |
| -from typing import Any |
| 4 | +from typing import Any, AsyncGenerator, Optional, Tuple |
5 | 5 | from quart import Blueprint, jsonify, request, Response, render_template, current_app
|
6 | 6 |
|
7 | 7 | import asyncio
|
|
18 | 18 | FileSearchTool,
|
19 | 19 | AsyncToolSet,
|
20 | 20 | FilePurpose,
|
| 21 | + ThreadMessage, |
| 22 | + ThreadError, |
| 23 | + StreamEventData, |
21 | 24 | AgentStreamEvent
|
22 | 25 | )
|
23 | 26 |
|
@@ -78,40 +81,44 @@ async def stop_server():
|
78 | 81 | await bp.ai_client.close()
|
79 | 82 | print("Closed AIProjectClient")
|
80 | 83 |
|
| 84 | +async def yield_callback(event_type: str, event_obj: StreamEventData, **kwargs) -> Optional[str]: |
| 85 | + accumulated_text = kwargs['accumulated_text'] |
| 86 | + if (isinstance(event_obj, MessageDeltaChunk)): |
| 87 | + for content_part in event_obj.delta.content: |
| 88 | + if isinstance(content_part, MessageDeltaTextContent): |
| 89 | + text_value = content_part.text.value if content_part.text else "No text" |
| 90 | + stream_data = json.dumps({'content': text_value, 'type': "message"}) |
| 91 | + accumulated_text[0] += text_value |
| 92 | + return f"data: {stream_data}\n\n" |
| 93 | + elif isinstance(event_obj, ThreadMessage): |
| 94 | + if (event_obj.status == "completed"): |
| 95 | + stream_data = json.dumps({'content': accumulated_text[0], 'type': "completed_message"}) |
| 96 | + return f"data: {stream_data}\n\n" |
| 97 | + elif isinstance(event_obj, ThreadError): |
| 98 | + print(f"An error occurred. Data: {event_obj.error}") |
| 99 | + stream_data = json.dumps({'type': "stream_end"}) |
| 100 | + return f"data: {stream_data}\n\n" |
| 101 | + elif event_type == AgentStreamEvent.DONE: |
| 102 | + stream_data = json.dumps({'type': "stream_end"}) |
| 103 | + return f"data: {stream_data}\n\n" |
| 104 | + |
| 105 | + return None |
81 | 106 | @bp.get("/")
|
82 | 107 | async def index():
|
83 | 108 | return await render_template("index.html")
|
84 | 109 |
|
85 |
| -async def create_stream(thread_id: str, agent_id: str): |
| 110 | + |
| 111 | + |
| 112 | +async def get_result(thread_id: str, agent_id: str): |
| 113 | + |
| 114 | + accumulated_text = [""] |
| 115 | + |
86 | 116 | async with await bp.ai_client.agents.create_stream(
|
87 |
| - thread_id=thread_id, assistant_id=agent_id |
| 117 | + thread_id=thread_id, assistant_id=agent_id, |
88 | 118 | ) as stream:
|
89 |
| - accumulated_text = "" |
90 |
| - |
91 |
| - async for event_type, event_data in stream: |
92 |
| - |
93 |
| - stream_data = None |
94 |
| - if isinstance(event_data, MessageDeltaChunk): |
95 |
| - for content_part in event_data.delta.content: |
96 |
| - if isinstance(content_part, MessageDeltaTextContent): |
97 |
| - text_value = content_part.text.value if content_part.text else "No text" |
98 |
| - accumulated_text += text_value |
99 |
| - print(f"Text delta received: {text_value}") |
100 |
| - stream_data = json.dumps({'content': text_value, 'type': "message"}) |
101 |
| - |
102 |
| - elif isinstance(event_data, ThreadMessage): |
103 |
| - print(f"ThreadMessage created. ID: {event_data.id}, Status: {event_data.status}") |
104 |
| - if (event_data.status == "completed"): |
105 |
| - stream_data = json.dumps({'content': accumulated_text, 'type': "completed_message"}) |
106 |
| - |
107 |
| - elif event_type == AgentStreamEvent.DONE: |
108 |
| - print("Stream completed.") |
109 |
| - stream_data = json.dumps({'type': "stream_end"}) |
110 |
| - |
111 |
| - if stream_data: |
112 |
| - yield f"data: {stream_data}\n\n" |
| 119 | + async for to_be_yield in stream.yield_until_done(yield_callback, accumulated_text=accumulated_text): |
| 120 | + yield to_be_yield |
113 | 121 |
|
114 |
| - |
115 | 122 | @bp.route('/chat', methods=['POST'])
|
116 | 123 | async def chat():
|
117 | 124 | thread_id = request.cookies.get('thread_id')
|
@@ -147,7 +154,7 @@ async def chat():
|
147 | 154 | 'Content-Type': 'text/event-stream'
|
148 | 155 | }
|
149 | 156 |
|
150 |
| - response = Response(create_stream(thread_id, agent_id), headers=headers) |
| 157 | + response = Response(get_result(thread_id, agent_id), headers=headers) |
151 | 158 | response.set_cookie('thread_id', thread_id)
|
152 | 159 | response.set_cookie('agent_id', agent_id)
|
153 | 160 | return response
|
|
0 commit comments