Skip to content

Commit 02bf49b

Browse files
committed
Move code that saves/loads chat session state into provider class
1 parent c0a0882 commit 02bf49b

File tree

5 files changed

+169
-81
lines changed

5 files changed

+169
-81
lines changed

Backend/api_models/chat_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic import BaseModel, Field
2+
from typing import List
3+
4+
class ChatSession(BaseModel):
5+
id: str # The session ID
6+
title: str # The title of the chat session
7+
history: List[dict] = Field(default_factory=list) # The chat history
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel
2+
3+
# Define the model for a Chat Session response
4+
class ChatSessionResponse(BaseModel):
5+
session_id: str
6+
title: str

Backend/app.py

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from fastapi import FastAPI, HTTPException
55
from fastapi.middleware.cors import CORSMiddleware
66

7+
from typing import List
8+
from chat_session_state.cosmosdb_chat_session_state_provider import CosmosDBChatSessionStateProvider
9+
from api_models.chat_session_request import ChatSessionResponse
10+
711
import uuid
812

913
from api_models.ai_request import AIRequest
@@ -58,64 +62,32 @@ def run_cosmic_works_ai_agent(request: AIRequest):
5862
# ========================
5963
# Chat Session State / History Support is below:
6064
# ========================
61-
import os
62-
from dotenv import load_dotenv
63-
from azure.cosmos import CosmosClient
64-
from pydantic import BaseModel
65-
from typing import List
66-
67-
load_dotenv()
6865

69-
# Your existing Cosmos DB client and container setup
70-
CONNECTION_STRING = os.environ.get("COSMOS_DB_CONNECTION_STRING")
71-
client = CosmosClient.from_connection_string(CONNECTION_STRING)
72-
db = client.get_database_client("cosmic_works")
73-
chat_session_container = db.get_container_client("chat_session")
66+
# Create an instance of the CosmosDBChatSessionStateProvider class
67+
cosmos_provider = CosmosDBChatSessionStateProvider()
7468

75-
# Define the model for a Chat Session response
76-
class ChatSessionResponse(BaseModel):
77-
session_id: str
78-
title: str
79-
80-
@app.get("/session/list") #, response_model=List[ChatSessionResponse])
69+
@app.get("/session/list", response_model=List[ChatSessionResponse])
8170
def list_sessions():
8271
"""
8372
Endpoint to list all chat sessions.
8473
"""
8574
try:
86-
# Query to get all sessions in the chat_session_container
87-
query = "SELECT c.id, c.title FROM c"
88-
sessions = list(chat_session_container.query_items(
89-
query=query,
90-
enable_cross_partition_query=True
91-
))
92-
93-
# Convert the sessions into a list of ChatSessionResponse objects
94-
session_responses = [ChatSessionResponse(session_id=session['id'], title=session['title']) for session in sessions]
95-
return session_responses
96-
except Exception as e:
97-
raise HTTPException(status_code=500, detail=f"Failed to retrieve sessions: {str(e)}")
98-
99-
# GET /session/load/{session_id}
75+
return cosmos_provider.list_sessions()
76+
except RuntimeError as e:
77+
# Return an internal server error if a runtime error occurs
78+
raise HTTPException(status_code=500, detail=str(e))
79+
80+
10081
@app.get("/session/load/{session_id}")
10182
def load_session(session_id: str):
10283
"""
10384
Endpoint to load a chat session by session_id.
10485
"""
10586
try:
106-
# Query to get the chat session with the provided session_id
107-
query = f"SELECT * FROM c WHERE c.id = '{session_id}'"
108-
session = list(chat_session_container.query_items(
109-
query=query,
110-
enable_cross_partition_query=True
111-
))
112-
113-
# If the session exists, return it
114-
if session:
115-
return session[0]
116-
else:
117-
raise HTTPException(status_code=404, detail="Session not found")
118-
except Exception as e:
119-
raise HTTPException(status_code=500, detail=f"Failed to retrieve session: {str(e)}")
120-
121-
87+
return cosmos_provider.load_session(session_id)
88+
except ValueError as e:
89+
# Return a 404 error if the session is not found
90+
raise HTTPException(status_code=404, detail=str(e))
91+
except RuntimeError as e:
92+
# Return an internal server error if a runtime error occurs
93+
raise HTTPException(status_code=500, detail=str(e))
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import os
2+
from datetime import datetime
3+
from typing import List, Optional
4+
from azure.cosmos import CosmosClient, PartitionKey, exceptions as cosmos_exceptions
5+
from dotenv import load_dotenv
6+
7+
from api_models.chat_session_request import ChatSessionResponse
8+
from api_models.chat_session import ChatSession
9+
10+
# Load environment variables
11+
load_dotenv()
12+
13+
# Initialize Cosmos DB client and container globally within the module
14+
CONNECTION_STRING = os.environ.get("COSMOS_DB_CONNECTION_STRING")
15+
client = CosmosClient.from_connection_string(CONNECTION_STRING)
16+
db = client.get_database_client("cosmic_works")
17+
18+
# Initialize the chat session container, create if not exists
19+
db.create_container_if_not_exists(id="chat_session", partition_key=PartitionKey(path="/id"))
20+
21+
chat_session_container = db.get_container_client("chat_session")
22+
23+
24+
class CosmosDBChatSessionStateProvider:
25+
"""
26+
A class to encapsulate CRUD operations for interacting with the chat session state in Cosmos DB.
27+
"""
28+
29+
def __init__(self, container=chat_session_container):
30+
self.container = container
31+
32+
def list_sessions(self) -> List[ChatSessionResponse]:
33+
"""
34+
Lists all chat sessions from the chat session container.
35+
36+
Returns:
37+
List[ChatSessionResponse]: A list of chat session responses.
38+
"""
39+
try:
40+
query = "SELECT c.id, c.title FROM c"
41+
sessions = list(self.container.query_items(
42+
query=query,
43+
enable_cross_partition_query=True
44+
))
45+
46+
# Convert the sessions into a list of ChatSessionResponse objects
47+
session_responses = [
48+
ChatSessionResponse(session_id=session['id'], title=session['title'])
49+
for session in sessions
50+
]
51+
return session_responses
52+
except cosmos_exceptions.CosmosHttpResponseError as e:
53+
raise RuntimeError(f"Failed to retrieve sessions: {str(e)}")
54+
55+
def load_or_create_chat_session(self, session_id: str) -> ChatSession:
56+
"""
57+
Load an existing session from the Cosmos DB container, or create a new one if not found.
58+
"""
59+
try:
60+
# Try to read the session from Cosmos DB
61+
session_item = chat_session_container.read_item(item=session_id, partition_key=session_id)
62+
return ChatSession(**session_item)
63+
except Exception:
64+
# If the session is not found, create a new one
65+
new_session = ChatSession(
66+
id=session_id,
67+
session_id=session_id,
68+
title=f"{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}",
69+
chat_history=[]
70+
)
71+
chat_session_container.upsert_item(new_session.model_dump())
72+
return new_session
73+
74+
def load_session(self, session_id: str) -> Optional[dict]:
75+
"""
76+
Loads a chat session by session ID.
77+
78+
Args:
79+
session_id (str): The ID of the session to be loaded.
80+
81+
Returns:
82+
Optional[dict]: The chat session data if found, else None.
83+
"""
84+
try:
85+
query = f"SELECT * FROM c WHERE c.id = '{session_id}'"
86+
session = list(self.container.query_items(
87+
query=query,
88+
enable_cross_partition_query=True
89+
))
90+
91+
if session:
92+
return session[0]
93+
else:
94+
raise ValueError("Session not found")
95+
except cosmos_exceptions.CosmosHttpResponseError as e:
96+
raise RuntimeError(f"Failed to retrieve session: {str(e)}")
97+
98+
def upsert_session(self, session: ChatSession) -> dict:
99+
"""
100+
Creates or updates a chat session in the chat session container.
101+
102+
Args:
103+
session: The chat session to create or update.
104+
105+
Returns:
106+
dict: The upserted session data.
107+
"""
108+
try:
109+
response = self.container.upsert_item(session.model_dump())
110+
return response
111+
except cosmos_exceptions.CosmosHttpResponseError as e:
112+
raise RuntimeError(f"Failed to create or update session: {str(e)}")
113+
114+
# def delete_session(self, session_id: str) -> None:
115+
# """
116+
# Deletes a chat session by session ID.
117+
118+
# Args:
119+
# session_id (str): The ID of the session to delete.
120+
# """
121+
# try:
122+
# self.container.delete_item(item=session_id, partition_key=session_id)
123+
# except cosmos_exceptions.CosmosResourceNotFoundError:
124+
# raise ValueError(f"Session with ID '{session_id}' not found")
125+
# except cosmos_exceptions.CosmosHttpResponseError as e:
126+
# raise RuntimeError(f"Failed to delete session: {str(e)}")

Backend/cosmic_works/cosmic_works_ai_agent.py

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77
"""
88
import os
99
import json
10-
from datetime import datetime
11-
from pydantic import BaseModel, Field
12-
from typing import Type, TypeVar, List
10+
from pydantic import BaseModel
11+
from typing import Type, TypeVar
1312
from dotenv import load_dotenv
1413
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
15-
from azure.cosmos import CosmosClient, ContainerProxy, PartitionKey
14+
from azure.cosmos import CosmosClient, ContainerProxy
1615
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
1716
from langchain_core.tools import StructuredTool
1817
from langchain.agents.agent_toolkits import create_retriever_tool
1918
from langchain.agents import AgentExecutor, create_openai_functions_agent
2019
from models import Product, SalesOrder
2120
from retrievers import AzureCosmosDBNoSQLRetriever
2221

22+
from chat_session_state.cosmosdb_chat_session_state_provider import CosmosDBChatSessionStateProvider
23+
2324
T = TypeVar('T', bound=BaseModel)
2425

2526
# Load settings for the notebook
@@ -37,15 +38,10 @@
3738
product_v_container = db.get_container_client("product_v")
3839
sales_order_container = db.get_container_client("salesOrder")
3940

40-
# Initialize the chat session container, create if not exists
41-
db.create_container_if_not_exists(id="chat_session", partition_key=PartitionKey(path="/id"))
42-
chat_session_container = db.get_container_client("chat_session")
43-
41+
# Create an instance of the CosmosDBChatSessionStateProvider class
42+
# This will be used to load or create Chat Sessions
43+
chat_session_state_provider = CosmosDBChatSessionStateProvider()
4444

45-
class ChatSession(BaseModel):
46-
id: str # The session ID
47-
title: str # The title of the chat session
48-
history: List[dict] = Field(default_factory=list) # The chat history
4945

5046
class CosmicWorksAIAgent:
5147
"""
@@ -56,7 +52,7 @@ class CosmicWorksAIAgent:
5652
def __init__(self, session_id: str):
5753
self.session_id = session_id
5854

59-
self.chat_session = self.load_or_create_chat_session(session_id)
55+
self.chat_session = chat_session_state_provider.load_or_create_chat_session(session_id)
6056

6157
llm = AzureChatOpenAI(
6258
temperature = 0,
@@ -128,28 +124,9 @@ def run(self, prompt: str) -> str:
128124
self.chat_session.history.append({"role": "assistant", "content": response})
129125

130126
# Save updated session chat history to Cosmos DB
131-
chat_session_container.upsert_item(self.chat_session.model_dump())
127+
chat_session_state_provider.upsert_session(self.chat_session)
132128

133129
return response
134-
135-
def load_or_create_chat_session(self, session_id: str) -> ChatSession:
136-
"""
137-
Load an existing session from the Cosmos DB container, or create a new one if not found.
138-
"""
139-
try:
140-
# Try to read the session from Cosmos DB
141-
session_item = chat_session_container.read_item(item=session_id, partition_key=session_id)
142-
return ChatSession(**session_item)
143-
except Exception:
144-
# If the session is not found, create a new one
145-
new_session = ChatSession(
146-
id=session_id,
147-
session_id=session_id,
148-
title=f"{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')}",
149-
chat_history=[]
150-
)
151-
chat_session_container.upsert_item(new_session.model_dump())
152-
return new_session
153130

154131
# Tools helper methods
155132
def delete_attribute_by_alias(instance: BaseModel, alias:str):

0 commit comments

Comments
 (0)