
Commit c6e86f0

Delete thread when it expires from the cache
1 parent daac493 commit c6e86f0

File tree

2 files changed: +69 -26 lines changed

src/api/agents/agent_factory.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ async def get_instance(cls, config):
             conn_str=config.azure_ai_project_conn_string
         )
 
-        agent_name = "agent"
+        agent_name = "ConversationKnowledgeAgent"
         agent_instructions = '''You are a helpful assistant.
         Always return the citations as is in final response.
         Always return citation markers in the answer as [doc1], [doc2], etc.

src/api/services/chat_service.py

Lines changed: 68 additions & 25 deletions
@@ -3,6 +3,9 @@
 import time
 import uuid
 from types import SimpleNamespace
+import asyncio
+import random
+import re
 
 import openai
 from fastapi import HTTPException, Request, status
@@ -16,8 +19,6 @@
 from helpers.utils import format_stream_response
 from cachetools import TTLCache
 
-thread_cache = TTLCache(maxsize=1000, ttl=3600)
-
 # Constants
 HOST_NAME = "CKM"
 HOST_INSTRUCTIONS = "Answer questions about call center operations"
@@ -26,8 +27,47 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+class ExpCache(TTLCache):
+    """
+    Extended TTLCache that holds an agent reference and deletes the corresponding Azure AI agent threads when items expire or are evicted (LRU).
+    """
+    def __init__(self, *args, agent=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.agent = agent
+
+    def expire(self, time=None):
+        items = super().expire(time)
+        for key, thread_id in items:
+            try:
+                if self.agent:
+                    thread = AzureAIAgentThread(client=self.agent.client, thread_id=thread_id)
+                    asyncio.create_task(thread.delete())
+                    logger.info("Thread deleted: %s", thread_id)
+            except Exception as e:
+                logger.error("Failed to delete thread for key %s: %s", key, e)
+        return items
+
+    def popitem(self):
+        key, thread_id = super().popitem()
+        try:
+            if self.agent:
+                thread = AzureAIAgentThread(client=self.agent.client, thread_id=thread_id)
+                asyncio.create_task(thread.delete())
+                logger.info("Thread deleted (LRU evict): %s", thread_id)
+        except Exception as e:
+            logger.error("Failed to delete thread for key %s (LRU evict): %s", key, e)
+        return key, thread_id
+
+
+# Global thread cache; the agent is attached lazily in ChatService.__init__.
+thread_cache = None
 
 
 class ChatService:
+    """
+    Service for handling chat interactions, including streaming responses,
+    processing RAG responses, and generating chart data for visualization.
+    """
     def __init__(self, request : Request):
         config = Config()
         self.azure_openai_endpoint = config.azure_openai_endpoint
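A note on the mechanics above: `Cache.popitem()` is what cachetools invokes internally to evict the least-recently-used entry once `maxsize` is exceeded, and `TTLCache.expire()` returns the expired `(key, value)` pairs only in newer cachetools releases (5.3+; earlier versions return None, which would break the loop). A minimal standalone sketch of the same hook pattern, with no Azure dependencies (`LoggingTTLCache` is an illustrative name, not part of this codebase):

import time

from cachetools import TTLCache


class LoggingTTLCache(TTLCache):
    """Illustrative only: report items as they expire or are LRU-evicted."""

    def expire(self, time=None):
        # cachetools >= 5.3: returns the expired (key, value) pairs.
        items = super().expire(time)
        for key, value in items:
            print(f"expired: {key} -> {value}")
        return items

    def popitem(self):
        # Called by __setitem__ when the cache is full (LRU eviction).
        key, value = super().popitem()
        print(f"evicted (LRU): {key} -> {value}")
        return key, value


cache = LoggingTTLCache(maxsize=2, ttl=0.1)
cache["a"] = "thread-1"
cache["b"] = "thread-2"
cache["c"] = "thread-3"  # over maxsize -> popitem() reports the LRU entry "a"
time.sleep(0.2)
cache.expire()           # past ttl -> expire() reports "b" and "c"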
@@ -37,6 +77,10 @@ def __init__(self, request : Request):
         self.azure_ai_project_conn_string = config.azure_ai_project_conn_string
         self.agent = request.app.state.agent
 
+        global thread_cache
+        if thread_cache is None:
+            thread_cache = ExpCache(maxsize=1000, ttl=3600.0, agent=self.agent)
+
     def process_rag_response(self, rag_response, query):
         """
         Parses the RAG response dynamically to extract chart data for Chart.js.
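One caveat with the fire-and-forget `asyncio.create_task(thread.delete())` calls in `ExpCache`: the event loop holds only weak references to tasks, so a task that nothing else references can be garbage-collected before it completes (and `create_task` requires a running loop in the first place). A hedged sketch of the usual mitigation; the `_background_tasks` set and helper name are assumptions, not part of this codebase:

import asyncio

# Keep strong references so pending deletions are not garbage-collected
# mid-flight; discard each task once it finishes.
_background_tasks: set = set()


def delete_in_background(thread) -> None:
    task = asyncio.create_task(thread.delete())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)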
@@ -64,7 +108,7 @@ def process_rag_response(self, rag_response, query):
             {query}
             {rag_response}
             """
-            logger.info(f">>> Processing chart data for response: {rag_response}")
+            logger.info(">>> Processing chart data for response: %s", rag_response)
 
             completion = client.chat.completions.create(
                 model=self.azure_openai_deployment_name,
@@ -76,12 +120,12 @@ def process_rag_response(self, rag_response, query):
             )
 
             chart_data = completion.choices[0].message.content.strip().replace("```json", "").replace("```", "")
-            logger.info(f">>> Generated chart data: {chart_data}")
+            logger.info(">>> Generated chart data: %s", chart_data)
 
             return json.loads(chart_data)
 
         except Exception as e:
-            logger.error(f"Error processing RAG response: {e}")
+            logger.error("Error processing RAG response: %s", e)
             return {"error": "Chart could not be generated from this data. Please ask a different question."}
 
     async def stream_openai_text(self, conversation_id: str, query: str) -> StreamingResponse:
@@ -94,44 +138,44 @@ async def stream_openai_text(self, conversation_id: str, query: str) -> StreamingResponse:
             if not query:
                 query = "Please provide a query."
 
-            # Create the AzureAI Agent
-            agent = self.agent
-
-            thread_id = thread_cache.get(conversation_id, None)
+            thread_id = None
+            if thread_cache is not None:
+                thread_id = thread_cache.get(conversation_id, None)
             if thread_id:
-                thread = AzureAIAgentThread(client=agent.client, thread_id=thread_id)
+                thread = AzureAIAgentThread(client=self.agent.client, thread_id=thread_id)
 
             truncation_strategy = TruncationObject(type="last_messages", last_messages=2)
 
-            async for response in agent.invoke_stream(messages=query, thread=thread, truncation_strategy=truncation_strategy):
-                thread_cache[conversation_id] = response.thread.id
+            async for response in self.agent.invoke_stream(messages=query, thread=thread, truncation_strategy=truncation_strategy):
+                if thread_cache is not None:
+                    thread_cache[conversation_id] = response.thread.id
                 complete_response += str(response.content)
                 yield response.content
 
         except RuntimeError as e:
             complete_response = str(e)
             if "Rate limit is exceeded" in str(e):
                 logger.error("Rate limit error: %s", e)
-                raise AgentException(f"Rate limit is exceeded. {str(e)}")
+                raise AgentException(f"Rate limit is exceeded. {str(e)}") from e
             else:
                 logger.error("RuntimeError: %s", e)
-                raise AgentException(f"An unexpected runtime error occurred: {str(e)}")
+                raise AgentException(f"An unexpected runtime error occurred: {str(e)}") from e
 
         except Exception as e:
             complete_response = str(e)
             logger.error("Error in stream_openai_text: %s", e)
-            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error streaming OpenAI text")
+            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error streaming OpenAI text") from e
 
         finally:
             # Provide a fallback response when no data is received from OpenAI.
             if complete_response == "":
                 logger.info("No response received from OpenAI.")
-                thread_cache.pop(conversation_id, None)
-                if thread:
-                    try:
-                        await thread.delete()
-                    except Exception as e:
-                        logger.warning("Failed to delete thread %s: %s", thread_id, e)
+                thread_id = None
+                if thread_cache is not None:
+                    thread_id = thread_cache.pop(conversation_id, None)
+                if thread_id is not None:
+                    corrupt_key = f"{conversation_id}_corrupt_{random.randint(1000, 9999)}"
+                    thread_cache[corrupt_key] = thread_id
                 yield "I cannot answer this question with the current data. Please rephrase or add more details."
 
     async def stream_chat_request(self, request_body, conversation_id, query):
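Worth noting on the `finally` block above: instead of awaiting `thread.delete()` inside the generator (the previous behavior), the orphaned thread id is parked under a fresh unique key, so the `ExpCache` expiry/eviction hooks perform the actual deletion later. A minimal sketch of that re-keying idea; the helper name is illustrative, while the key format mirrors the diff:

import random

from cachetools import TTLCache


def park_for_cleanup(cache: TTLCache, conversation_id: str, thread_id: str) -> str:
    # Re-key the orphaned thread so the cache's expire()/popitem() hooks
    # delete the remote thread later, instead of awaiting here.
    corrupt_key = f"{conversation_id}_corrupt_{random.randint(1000, 9999)}"
    cache[corrupt_key] = thread_id
    return corrupt_key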
@@ -186,18 +230,17 @@ async def generate():
                 error_message = str(e)
                 retry_after = "sometime"
                 if "Rate limit is exceeded" in error_message:
-                    import re
                     match = re.search(r"Try again in (\d+) seconds", error_message)
                     if match:
                         retry_after = f"{match.group(1)} seconds"
-                    logger.error(f"Rate limit error: {error_message}")
+                    logger.error("Rate limit error: %s", error_message)
                     yield json.dumps({"error": f"Rate limit is exceeded. Try again in {retry_after}."}) + "\n\n"
                 else:
-                    logger.error(f"AgentInvokeException: {error_message}")
+                    logger.error("AgentInvokeException: %s", error_message)
                     yield json.dumps({"error": "An error occurred. Please try again later."}) + "\n\n"
 
             except Exception as e:
-                logger.error(f"Error in stream_chat_request: {e}", exc_info=True)
+                logger.error("Error in stream_chat_request: %s", e, exc_info=True)
                 yield json.dumps({"error": "An error occurred while processing the request."}) + "\n\n"
 
         return generate()
