-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathroutes.py
More file actions
361 lines (306 loc) · 14.1 KB
/
routes.py
File metadata and controls
361 lines (306 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
import asyncio
import json
import os
from typing import AsyncGenerator, Optional, Dict
import fastapi
from fastapi import Request, Depends, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from fastapi.responses import JSONResponse
import logging
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from azure.ai.agents.aio import AgentsClient
from azure.ai.agents.models import (
Agent,
MessageDeltaChunk,
ThreadMessage,
ThreadRun,
AsyncAgentEventHandler,
RunStep
)
from azure.ai.projects import AIProjectClient
from azure.ai.projects.models import (
AgentEvaluationRequest,
AgentEvaluationSamplingConfiguration,
AgentEvaluationRedactionConfiguration,
EvaluatorIds
)
# Application-wide logger, shared under the fixed name "azureaiapp" (not __name__).
logger = logging.getLogger("azureaiapp")
# Raise the Azure SDK HTTP logging policy to WARNING so per-request INFO
# request/response logs from azure.core are suppressed.
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
from opentelemetry import trace
# Tracer used for the spans opened by the handlers in this module.
tracer = trace.get_tracer(__name__)
# Jinja2 templates are loaded from the "templates" directory next to this file.
directory = os.path.join(os.path.dirname(__file__), "templates")
templates = Jinja2Templates(directory=directory)
# Router holding this module's endpoints; presumably included into the
# FastAPI app elsewhere — confirm against the app factory.
router = fastapi.APIRouter()
def get_ai_project(request: Request) -> AIProjectClient:
    """FastAPI dependency returning the project client stored on ``app.state``."""
    app_state = request.app.state
    return app_state.ai_project
def get_agent_client(request: Request) -> AgentsClient:
    """FastAPI dependency returning the agents client stored on ``app.state``."""
    app_state = request.app.state
    return app_state.agent_client
def get_agent(request: Request) -> Agent:
    """FastAPI dependency returning the agent definition stored on ``app.state``."""
    app_state = request.app.state
    return app_state.agent
def get_app_insights_conn_str(request: Request) -> Optional[str]:
    """FastAPI dependency returning the Application Insights connection string.

    Returns ``None`` when the app did not store
    ``application_insights_connection_string`` on ``app.state``. The return
    annotation is ``Optional[str]`` because ``None`` is a real result.
    """
    return getattr(request.app.state, "application_insights_connection_string", None)
def serialize_sse_event(data: Dict) -> str:
    """Encode *data* as one Server-Sent Events ``data:`` frame.

    The payload is JSON-encoded and followed by the blank line that
    terminates an SSE event.
    """
    encoded = json.dumps(data)
    return "data: " + encoded + "\n\n"
async def get_message_and_annotations(agent_client : AgentsClient, message: ThreadMessage) -> Dict:
    """Build a JSON-serializable view of a thread message and its citations.

    Args:
        agent_client: Client used to look up file names for file citations.
        message: The thread message to format.

    Returns:
        ``{'content': str, 'annotations': list[dict]}``. ``content`` is the
        text of the message's first text item, or ``''`` when the message has
        no text items. Each annotation dict gets an extra ``file_name`` key
        for display.
    """
    annotations = []
    # File citations (file search) only carry a file ID; resolve each one to
    # the stored file's name so the UI can show something readable.
    for annotation in (a.as_dict() for a in message.file_citation_annotations):
        file_id = annotation["file_citation"]["file_id"]
        logger.info(f"Fetching file with ID for annotation {file_id}")
        openai_file = await agent_client.files.get(file_id)
        annotation["file_name"] = openai_file.filename
        logger.info(f"File name for annotation: {annotation['file_name']}")
        annotations.append(annotation)
    # URL citations (index search) already include a title; reuse it as the
    # display name.
    for url_annotation in message.url_citation_annotations:
        annotation = url_annotation.as_dict()
        annotation["file_name"] = annotation['url_citation']['title']
        logger.info(f"File name for annotation: {annotation['file_name']}")
        annotations.append(annotation)
    # Indexing [0] on an empty list would raise IndexError and fail the whole
    # history/stream request, so fall back to empty content instead.
    text_messages = message.text_messages
    content = text_messages[0].text.value if text_messages else ''
    return {
        'content': content,
        'annotations': annotations
    }
class MyEventHandler(AsyncAgentEventHandler[str]):
    """Convert agent run stream events into Server-Sent Event (SSE) strings.

    Every ``on_*`` callback returns either an SSE-formatted string built with
    ``serialize_sse_event`` or ``None`` when there is nothing to send.
    ``get_result`` in this module forwards the non-``None`` values to the
    client.
    """

    def __init__(self, ai_project: AIProjectClient, app_insights_conn_str: Optional[str]):
        """
        Args:
            ai_project: Project client. Its ``agents`` client resolves file
                citations, and it is passed to ``run_agent_evaluation``.
            app_insights_conn_str: Application Insights connection string.
                When falsy, ``run_agent_evaluation`` schedules nothing.
        """
        super().__init__()
        self.agent_client = ai_project.agents
        self.ai_project = ai_project
        self.app_insights_conn_str = app_insights_conn_str

    async def on_message_delta(self, delta: MessageDeltaChunk) -> Optional[str]:
        """Forward each incremental text chunk as a ``message`` event."""
        stream_data = {'content': delta.text, 'type': "message"}
        return serialize_sse_event(stream_data)

    async def on_thread_message(self, message: ThreadMessage) -> Optional[str]:
        """Send the full message, with resolved citations, once it completes.

        Messages whose status is not ``completed`` are skipped. Errors are
        logged and swallowed (returning ``None``) so that a failure while
        resolving annotations does not abort the stream.
        """
        try:
            logger.info(f"MyEventHandler: Received thread message, message ID: {message.id}, status: {message.status}")
            if message.status != "completed":
                return None
            logger.info("MyEventHandler: Received completed message")
            stream_data = await get_message_and_annotations(self.agent_client, message)
            stream_data['type'] = "completed_message"
            return serialize_sse_event(stream_data)
        except Exception as e:
            logger.error(f"Error in event handler for thread message: {e}", exc_info=True)
            return None

    async def on_thread_run(self, run: ThreadRun) -> Optional[str]:
        """Report run status changes and start an evaluation when a run completes."""
        logger.info("MyEventHandler: on_thread_run event received")
        run_information = f"ThreadRun status: {run.status}, thread ID: {run.thread_id}"
        stream_data = {'content': run_information, 'type': 'thread_run'}
        if run.status == "failed":
            # last_error may be None. Calling .as_dict() on it would raise
            # AttributeError and end the stream with a generic error instead
            # of reporting the failed run.
            last_error = run.last_error
            if last_error is not None:
                stream_data['error'] = last_error.as_dict()
            else:
                stream_data['error'] = {'message': f"Run {run.id} failed without error details"}
        # Automatically run agent evaluation when the run is completed.
        if run.status == "completed":
            run_agent_evaluation(run.thread_id, run.id, self.ai_project, self.app_insights_conn_str)
        return serialize_sse_event(stream_data)

    async def on_error(self, data: str) -> Optional[str]:
        """Log the stream error and tell the client the stream has ended."""
        logger.error(f"MyEventHandler: on_error event received: {data}")
        stream_data = {'type': "stream_end"}
        return serialize_sse_event(stream_data)

    async def on_done(self) -> Optional[str]:
        """Tell the client the stream has ended."""
        logger.info("MyEventHandler: on_done event received")
        stream_data = {'type': "stream_end"}
        return serialize_sse_event(stream_data)

    async def on_run_step(self, step: RunStep) -> Optional[str]:
        """Log Azure AI Search tool-call inputs/outputs; never emits SSE data."""
        logger.info(f"Step {step['id']} status: {step['status']}")
        # `or` also covers keys that are present with a None value, which
        # `.get(key, default)` would return as-is.
        step_details = step.get("step_details") or {}
        tool_calls = step_details.get("tool_calls") or []
        if tool_calls:
            logger.info("Tool calls:")
            for call in tool_calls:
                azure_ai_search_details = call.get("azure_ai_search") or {}
                if azure_ai_search_details:
                    logger.info(f"azure_ai_search input: {azure_ai_search_details.get('input')}")
                    logger.info(f"azure_ai_search output: {azure_ai_search_details.get('output')}")
        return None
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` from the templates directory."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def get_result(
    request: Request,
    thread_id: str,
    agent_id: str,
    ai_project: AIProjectClient,
    app_insight_conn_str: Optional[str],
    carrier: Dict[str, str]
) -> AsyncGenerator[str, None]:
    """Run the agent on a thread and yield its streamed output as SSE strings.

    Args:
        request: Currently unused by the body.
        thread_id: Thread to run the agent on.
        agent_id: Agent to run.
        ai_project: Project client whose ``agents.runs.stream`` drives the run.
        app_insight_conn_str: Forwarded to ``MyEventHandler`` for evaluations.
        carrier: Trace-context headers captured by the caller. This generator
            is consumed by ``StreamingResponse`` outside the caller's span, so
            the context is re-extracted here to parent the ``get_result`` span.

    Any exception is logged and reported to the client as a final
    ``{"type": "error"}`` SSE event instead of being raised.
    """
    parent_context = TraceContextTextMapPropagator().extract(carrier=carrier)
    with tracer.start_as_current_span('get_result', context=parent_context):
        logger.info(f"get_result invoked for thread_id={thread_id} and agent_id={agent_id}")
        try:
            runs_api = ai_project.agents.runs
            handler = MyEventHandler(ai_project, app_insight_conn_str)
            stream_manager = await runs_api.stream(
                thread_id=thread_id,
                agent_id=agent_id,
                event_handler=handler,
            )
            async with stream_manager as stream:
                logger.info("Successfully created stream; starting to process events")
                async for event in stream:
                    _event_type, _event_data, payload = event
                    logger.debug(f"Received event: {event}")
                    if not payload:
                        logger.debug("Event received but no data to yield")
                        continue
                    logger.info(f"Yielding event: {payload}")
                    yield payload
        except Exception as e:
            logger.exception(f"Exception in get_result: {e}")
            yield serialize_sse_event({'type': "error", 'message': str(e)})
@router.get("/chat/history")
async def history(
    request: Request,
    ai_project : AIProjectClient = Depends(get_ai_project),
    agent : Agent = Depends(get_agent),
):
    """Return the messages of the caller's thread as a JSON list.

    The thread ID is read from the ``thread_id`` cookie and reused only when
    the ``agent_id`` cookie matches the current agent; otherwise a new thread
    is created. The response sets both cookies to the thread/agent actually
    used.

    Raises:
        HTTPException: 400 if the thread cannot be retrieved or created,
            500 if listing or formatting the messages fails.
    """
    with tracer.start_as_current_span("chat_history"):
        # Retrieve the thread and agent IDs from the cookies (if available).
        thread_id = request.cookies.get('thread_id')
        agent_id = request.cookies.get('agent_id')

        # Reuse the existing thread when it belongs to the current agent,
        # otherwise create a new one.
        try:
            agent_client = ai_project.agents
            if thread_id and agent_id == agent.id:
                logger.info(f"Retrieving thread with ID {thread_id}")
                thread = await agent_client.threads.get(thread_id)
            else:
                logger.info("Creating a new thread")
                thread = await agent_client.threads.create()
        except Exception as e:
            logger.error(f"Error handling thread: {e}")
            raise HTTPException(status_code=400, detail=f"Error handling thread: {e}") from e

        thread_id = thread.id
        agent_id = agent.id

        # List the thread's messages and format each one for the UI.
        try:
            content = []
            messages = agent_client.messages.list(
                thread_id=thread_id,
            )
            async for message in messages:
                formatted_message = await get_message_and_annotations(agent_client, message)
                formatted_message['role'] = message.role
                formatted_message['created_at'] = message.created_at.astimezone().strftime("%m/%d/%y, %I:%M %p")
                content.append(formatted_message)
            logger.info(f"List message, thread ID: {thread_id}")
            response = JSONResponse(content=content)
            # Update cookies to persist the thread and agent IDs.
            response.set_cookie("thread_id", thread_id)
            response.set_cookie("agent_id", agent_id)
            return response
        except Exception as e:
            logger.error(f"Error listing message: {e}")
            raise HTTPException(status_code=500, detail=f"Error list message: {e}") from e
@router.get("/agent")
async def get_chat_agent(
    request: Request
):
    """Return the configured agent's definition as JSON."""
    current_agent = get_agent(request)
    payload = current_agent.as_dict()
    return JSONResponse(content=payload)
@router.post("/chat")
async def chat(
    request: Request,
    agent : Agent = Depends(get_agent),
    ai_project: AIProjectClient = Depends(get_ai_project),
    app_insights_conn_str : Optional[str] = Depends(get_app_insights_conn_str)
):
    """Add the user's message to their thread and stream the agent's reply as SSE.

    The request body must be a JSON object. Its ``message`` field (default
    ``''``) becomes the user message. The thread ID is taken from the
    ``thread_id`` cookie when the ``agent_id`` cookie matches the current
    agent; otherwise a new thread is created. The response sets both cookies.

    Raises:
        HTTPException: 400 for an unparsable or non-object body, or when the
            thread cannot be retrieved or created; 500 if the message cannot
            be created.
    """
    # Retrieve the thread ID from the cookies (if available).
    thread_id = request.cookies.get('thread_id')
    agent_id = request.cookies.get('agent_id')
    with tracer.start_as_current_span("chat_request"):
        # Capture the current trace context so the streaming generator, which
        # runs after this handler returns, can attach its span to this request.
        carrier = {}
        TraceContextTextMapPropagator().inject(carrier)

        # Parse and validate the body before touching threads, so a malformed
        # request does not create an orphaned thread.
        try:
            user_message = await request.json()
        except Exception as e:
            logger.error(f"Invalid JSON in request: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid JSON in request: {e}") from e
        # A JSON list/string/number would otherwise fail on .get() below and
        # surface as an unhandled 500.
        if not isinstance(user_message, dict):
            logger.error("Invalid request body: expected a JSON object")
            raise HTTPException(status_code=400, detail="Invalid request body: expected a JSON object")
        logger.info(f"user_message: {user_message}")

        # Attempt to get an existing thread. If not found, create a new one.
        try:
            agent_client = ai_project.agents
            if thread_id and agent_id == agent.id:
                logger.info(f"Retrieving thread with ID {thread_id}")
                thread = await agent_client.threads.get(thread_id)
            else:
                logger.info("Creating a new thread")
                thread = await agent_client.threads.create()
        except Exception as e:
            logger.error(f"Error handling thread: {e}")
            raise HTTPException(status_code=400, detail=f"Error handling thread: {e}") from e

        thread_id = thread.id
        agent_id = agent.id

        # Create a new message from the user's input.
        try:
            message = await agent_client.messages.create(
                thread_id=thread_id,
                role="user",
                content=user_message.get('message', '')
            )
            logger.info(f"Created message, message ID: {message.id}")
        except Exception as e:
            logger.error(f"Error creating message: {e}")
            raise HTTPException(status_code=500, detail=f"Error creating message: {e}") from e

        # Set the Server-Sent Events (SSE) response headers.
        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream"
        }

        logger.info(f"Starting streaming response for thread ID {thread_id}")

        # Create the streaming response using the generator.
        response = StreamingResponse(get_result(request, thread_id, agent_id, ai_project, app_insights_conn_str, carrier), headers=headers)

        # Update cookies to persist the thread and agent IDs.
        response.set_cookie("thread_id", thread_id)
        response.set_cookie("agent_id", agent_id)
        return response
def read_file(path: str, encoding: str = "utf-8") -> str:
    """Return the entire text content of the file at *path*.

    Args:
        path: Filesystem path of the file to read.
        encoding: Text encoding used to decode the file. It defaults to UTF-8
            explicitly because ``open()`` without ``encoding`` uses the
            locale's preferred encoding, which varies by platform.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the content is not valid in *encoding*.
    """
    with open(path, 'r', encoding=encoding) as file:
        return file.read()
# Strong references to in-flight evaluation tasks. asyncio keeps only weak
# references to tasks, so a fire-and-forget task could be garbage-collected
# before it finishes; each task removes itself from this set when done.
_background_tasks = set()


def run_agent_evaluation(
    thread_id: str,
    run_id: str,
    ai_project: AIProjectClient,
    app_insights_conn_str: Optional[str]):
    """Schedule a background agent evaluation for a finished run.

    The evaluation uses the Relevance, TaskAdherence and ToolCallAccuracy
    evaluators on 100% of samples and reports to Application Insights. The
    function returns immediately; the evaluation runs as an asyncio task,
    and failures in it are logged, not raised.

    Args:
        thread_id: Thread containing the run.
        run_id: Run to evaluate.
        ai_project: Client whose ``evaluations.create_agent_evaluation`` is awaited.
        app_insights_conn_str: Application Insights connection string. When
            falsy, nothing is scheduled.

    Raises:
        RuntimeError: From ``asyncio.create_task`` if called without a running
            event loop.
    """
    if not app_insights_conn_str:
        return

    agent_evaluation_request = AgentEvaluationRequest(
        run_id=run_id,
        thread_id=thread_id,
        evaluators={
            "Relevance": {"Id": EvaluatorIds.RELEVANCE.value},
            "TaskAdherence": {"Id": EvaluatorIds.TASK_ADHERENCE.value},
            "ToolCallAccuracy": {"Id": EvaluatorIds.TOOL_CALL_ACCURACY.value},
        },
        sampling_configuration=AgentEvaluationSamplingConfiguration(
            name="default",
            sampling_percent=100,
        ),
        redaction_configuration=AgentEvaluationRedactionConfiguration(
            redact_score_properties=False,
        ),
        app_insights_connection_string=app_insights_conn_str,
    )

    async def run_evaluation():
        try:
            logger.info(f"Running agent evaluation on thread ID {thread_id} and run ID {run_id}")
            agent_evaluation_response = await ai_project.evaluations.create_agent_evaluation(
                evaluation=agent_evaluation_request
            )
            logger.info(f"Evaluation response: {agent_evaluation_response}")
        except Exception as e:
            logger.error(f"Error creating agent evaluation: {e}")

    # Run the evaluation asynchronously, keeping a reference until it finishes.
    task = asyncio.create_task(run_evaluation())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)