44from azure .cosmos import exceptions
55
66
7- class CosmosConversationClient ():
8-
9- def __init__ (self , cosmosdb_endpoint : str , credential : any , database_name : str , container_name : str , enable_message_feedback : bool = False ):
7+ class CosmosConversationClient :
8+ def __init__ (
9+ self ,
10+ cosmosdb_endpoint : str ,
11+ credential : any ,
12+ database_name : str ,
13+ container_name : str ,
14+ enable_message_feedback : bool = False ,
15+ ):
1016 self .cosmosdb_endpoint = cosmosdb_endpoint
1117 self .credential = credential
1218 self .database_name = database_name
1319 self .container_name = container_name
1420 self .enable_message_feedback = enable_message_feedback
1521 try :
16- self .cosmosdb_client = CosmosClient (self .cosmosdb_endpoint , credential = credential )
22+ self .cosmosdb_client = CosmosClient (
23+ self .cosmosdb_endpoint , credential = credential
24+ )
1725 except exceptions .CosmosHttpResponseError as e :
1826 if e .status_code == 401 :
1927 raise ValueError ("Invalid credentials" ) from e
2028 else :
2129 raise ValueError ("Invalid CosmosDB endpoint" ) from e
2230
2331 try :
24- self .database_client = self .cosmosdb_client .get_database_client (database_name )
32+ self .database_client = self .cosmosdb_client .get_database_client (
33+ database_name
34+ )
2535 except exceptions .CosmosResourceNotFoundError :
2636 raise ValueError ("Invalid CosmosDB database name" )
2737
2838 try :
29- self .container_client = self .database_client .get_container_client (container_name )
39+ self .container_client = self .database_client .get_container_client (
40+ container_name
41+ )
3042 except exceptions .CosmosResourceNotFoundError :
3143 raise ValueError ("Invalid CosmosDB container name" )
3244
3345 async def ensure (self ):
34- if not self .cosmosdb_client or not self .database_client or not self .container_client :
46+ if (
47+ not self .cosmosdb_client
48+ or not self .database_client
49+ or not self .container_client
50+ ):
3551 return False , "CosmosDB client not initialized correctly"
3652 try :
3753 database_info = await self .database_client .read ()
3854 except :
39- return False , f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found"
55+ return (
56+ False ,
57+ f"CosmosDB database { self .database_name } on account { self .cosmosdb_endpoint } not found" ,
58+ )
4059
4160 try :
4261 container_info = await self .container_client .read ()
@@ -45,14 +64,14 @@ async def ensure(self):
4564
4665 return True , "CosmosDB client initialized successfully"
4766
48- async def create_conversation (self , user_id , title = '' ):
67+ async def create_conversation (self , user_id , title = "" ):
4968 conversation = {
50- 'id' : str (uuid .uuid4 ()),
51- ' type' : ' conversation' ,
52- ' createdAt' : datetime .utcnow ().isoformat (),
53- ' updatedAt' : datetime .utcnow ().isoformat (),
54- ' userId' : user_id ,
55- ' title' : title
69+ "id" : str (uuid .uuid4 ()),
70+ " type" : " conversation" ,
71+ " createdAt" : datetime .utcnow ().isoformat (),
72+ " updatedAt" : datetime .utcnow ().isoformat (),
73+ " userId" : user_id ,
74+ " title" : title ,
5675 }
5776 # TODO: add some error handling based on the output of the upsert_item call
5877 resp = await self .container_client .upsert_item (conversation )
@@ -69,9 +88,13 @@ async def upsert_conversation(self, conversation):
6988 return False
7089
7190 async def delete_conversation (self , user_id , conversation_id ):
72- conversation = await self .container_client .read_item (item = conversation_id , partition_key = user_id )
91+ conversation = await self .container_client .read_item (
92+ item = conversation_id , partition_key = user_id
93+ )
7394 if conversation :
74- resp = await self .container_client .delete_item (item = conversation_id , partition_key = user_id )
95+ resp = await self .container_client .delete_item (
96+ item = conversation_id , partition_key = user_id
97+ )
7598 return resp
7699 else :
77100 return True
@@ -82,41 +105,36 @@ async def delete_messages(self, conversation_id, user_id):
82105 response_list = []
83106 if messages :
84107 for message in messages :
85- resp = await self .container_client .delete_item (item = message ['id' ], partition_key = user_id )
108+ resp = await self .container_client .delete_item (
109+ item = message ["id" ], partition_key = user_id
110+ )
86111 response_list .append (resp )
87112 return response_list
88113
89- async def get_conversations (self , user_id , limit , sort_order = 'DESC' , offset = 0 ):
90- parameters = [
91- {
92- 'name' : '@userId' ,
93- 'value' : user_id
94- }
95- ]
114+ async def get_conversations (self , user_id , limit , sort_order = "DESC" , offset = 0 ):
115+ parameters = [{"name" : "@userId" , "value" : user_id }]
96116 query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt { sort_order } "
97117 if limit is not None :
98118 query += f" offset { offset } limit { limit } "
99119
100120 conversations = []
101- async for item in self .container_client .query_items (query = query , parameters = parameters ):
121+ async for item in self .container_client .query_items (
122+ query = query , parameters = parameters
123+ ):
102124 conversations .append (item )
103125
104126 return conversations
105127
106128 async def get_conversation (self , user_id , conversation_id ):
107129 parameters = [
108- {
109- 'name' : '@conversationId' ,
110- 'value' : conversation_id
111- },
112- {
113- 'name' : '@userId' ,
114- 'value' : user_id
115- }
130+ {"name" : "@conversationId" , "value" : conversation_id },
131+ {"name" : "@userId" , "value" : user_id },
116132 ]
117133 query = f"SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
118134 conversations = []
119- async for item in self .container_client .query_items (query = query , parameters = parameters ):
135+ async for item in self .container_client .query_items (
136+ query = query , parameters = parameters
137+ ):
120138 conversations .append (item )
121139
122140 # if no conversations are found, return None
@@ -127,54 +145,52 @@ async def get_conversation(self, user_id, conversation_id):
127145
128146 async def create_message (self , uuid , conversation_id , user_id , input_message : dict ):
129147 message = {
130- 'id' : uuid ,
131- ' type' : ' message' ,
132- ' userId' : user_id ,
133- ' createdAt' : datetime .utcnow ().isoformat (),
134- ' updatedAt' : datetime .utcnow ().isoformat (),
135- ' conversationId' : conversation_id ,
136- ' role' : input_message [' role' ],
137- ' content' : input_message [' content' ]
148+ "id" : uuid ,
149+ " type" : " message" ,
150+ " userId" : user_id ,
151+ " createdAt" : datetime .utcnow ().isoformat (),
152+ " updatedAt" : datetime .utcnow ().isoformat (),
153+ " conversationId" : conversation_id ,
154+ " role" : input_message [" role" ],
155+ " content" : input_message [" content" ],
138156 }
139157
140158 if self .enable_message_feedback :
141- message [' feedback' ] = ''
159+ message [" feedback" ] = ""
142160
143161 resp = await self .container_client .upsert_item (message )
144162 if resp :
145163 # update the parent conversations's updatedAt field with the current message's createdAt datetime value
146164 conversation = await self .get_conversation (user_id , conversation_id )
147165 if not conversation :
148166 return "Conversation not found"
149- conversation [' updatedAt' ] = message [' createdAt' ]
167+ conversation [" updatedAt" ] = message [" createdAt" ]
150168 await self .upsert_conversation (conversation )
151169 return resp
152170 else :
153171 return False
154172
155173 async def update_message_feedback (self , user_id , message_id , feedback ):
156- message = await self .container_client .read_item (item = message_id , partition_key = user_id )
174+ message = await self .container_client .read_item (
175+ item = message_id , partition_key = user_id
176+ )
157177 if message :
158- message [' feedback' ] = feedback
178+ message [" feedback" ] = feedback
159179 resp = await self .container_client .upsert_item (message )
160180 return resp
161181 else :
162182 return False
163183
164184 async def get_messages (self , user_id , conversation_id ):
165185 parameters = [
166- {
167- 'name' : '@conversationId' ,
168- 'value' : conversation_id
169- },
170- {
171- 'name' : '@userId' ,
172- 'value' : user_id
173- }
186+ {"name" : "@conversationId" , "value" : conversation_id },
187+ {"name" : "@userId" , "value" : user_id },
174188 ]
175189 query = f"SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.timestamp ASC"
176190 messages = []
177- async for item in self .container_client .query_items (query = query , parameters = parameters ):
191+ async for item in self .container_client .query_items (
192+ query = query , parameters = parameters
193+ ):
178194 messages .append (item )
179195
180196 return messages
0 commit comments