1111from azure .cosmos .aio import CosmosClient
1212from azure .cosmos .aio ._database import DatabaseProxy
1313from azure .cosmos .exceptions import CosmosResourceExistsError
14- from pytest import Session
14+ import v3 . models . messages as messages
1515
1616from common .models .messages_kernel import (
1717 AgentMessage ,
2323
2424from .database_base import DatabaseBase
2525from ..models .messages_kernel import (
26+ AgentMessageData ,
2627 BaseDataModel ,
27- Session ,
2828 Plan ,
2929 Step ,
3030 AgentMessage ,
@@ -38,7 +38,6 @@ class CosmosDBClient(DatabaseBase):
3838 """CosmosDB implementation of the database interface."""
3939
4040 MODEL_CLASS_MAPPING = {
41- DataType .session : Session ,
4241 DataType .plan : Plan ,
4342 DataType .step : Step ,
4443 DataType .agent_message : AgentMessage ,
@@ -190,29 +189,6 @@ async def delete_item(self, item_id: str, partition_key: str) -> None:
190189 self .logger .error ("Failed to delete item from CosmosDB: %s" , str (e ))
191190 raise
192191
193- # Session Operations
194- async def add_session (self , session : Session ) -> None :
195- """Add a session to CosmosDB."""
196- await self .add_item (session )
197-
198- async def get_session (self , session_id : str ) -> Optional [Session ]:
199- """Retrieve a session by session_id."""
200- query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
201- parameters = [
202- {"name" : "@id" , "value" : session_id },
203- {"name" : "@data_type" , "value" : DataType .session },
204- ]
205- results = await self .query_items (query , parameters , Session )
206- return results [0 ] if results else None
207-
208- async def get_all_sessions (self ) -> List [Session ]:
209- """Retrieve all sessions for the user."""
210- query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
211- parameters = [
212- {"name" : "@user_id" , "value" : self .user_id },
213- {"name" : "@data_type" , "value" : DataType .session },
214- ]
215- return await self .query_items (query , parameters , Session )
216192
217193 # Plan Operations
218194 async def add_plan (self , plan : Plan ) -> None :
@@ -223,17 +199,6 @@ async def update_plan(self, plan: Plan) -> None:
223199 """Update a plan in CosmosDB."""
224200 await self .update_item (plan )
225201
226- async def get_plan_by_session (self , session_id : str ) -> Optional [Plan ]:
227- """Retrieve a plan by session_id."""
228- query = (
229- "SELECT * FROM c WHERE c.session_id=@session_id AND c.data_type=@data_type"
230- )
231- parameters = [
232- {"name" : "@session_id" , "value" : session_id },
233- {"name" : "@data_type" , "value" : DataType .plan },
234- ]
235- results = await self .query_items (query , parameters , Plan )
236- return results [0 ] if results else None
237202
238203 async def get_plan_by_plan_id (self , plan_id : str ) -> Optional [Plan ]:
239204 """Retrieve a plan by plan_id."""
@@ -272,7 +237,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
272237
273238 async def get_all_plans_by_team_id_status (self , team_id : str , status : str ) -> List [Plan ]:
274239 """Retrieve all plans for a specific team."""
275- query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status"
240+ query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status ORDER BY c._ts DESC "
276241 parameters = [
277242 {"name" : "@user_id" , "value" : self .user_id },
278243 {"name" : "@team_id" , "value" : team_id },
@@ -328,7 +293,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
328293 teams = await self .query_items (query , parameters , TeamConfiguration )
329294 return teams [0 ] if teams else None
330295
331- async def get_team_by_id (self , id : str ) -> Optional [TeamConfiguration ]:
296+ async def get_team_by_id (self , team_id : str ) -> Optional [TeamConfiguration ]:
332297 """Retrieve a specific team configuration by its document id.
333298
334299 Args:
@@ -337,9 +302,9 @@ async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
337302 Returns:
338303 TeamConfiguration object or None if not found
339304 """
340- query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
305+ query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type"
341306 parameters = [
342- {"name" : "@id " , "value" : id },
307+ {"name" : "@team_id " , "value" : team_id },
343308 {"name" : "@data_type" , "value" : DataType .team_config },
344309 ]
345310 teams = await self .query_items (query , parameters , TeamConfiguration )
@@ -383,27 +348,6 @@ async def delete_team(self, team_id: str) -> bool:
383348 logging .exception (f"Failed to delete team from Cosmos DB: { e } " )
384349 return False
385350
386- async def get_data_by_type_and_session_id (
387- self , data_type : str , session_id : str
388- ) -> List [BaseDataModel ]:
389- """Query the Cosmos DB for documents with the matching data_type, session_id and user_id."""
390- await self ._ensure_initialized ()
391- if self .container is None :
392- return []
393-
394- model_class = self .MODEL_CLASS_MAPPING .get (data_type , BaseDataModel )
395- try :
396- query = "SELECT * FROM c WHERE c.session_id=@session_id AND c.user_id=@user_id AND c.data_type=@data_type ORDER BY c._ts ASC"
397- parameters = [
398- {"name" : "@session_id" , "value" : session_id },
399- {"name" : "@data_type" , "value" : data_type },
400- {"name" : "@user_id" , "value" : self .user_id },
401- ]
402- return await self .query_items (query , parameters , model_class )
403- except Exception as e :
404- logging .exception (f"Failed to query data by type from Cosmos DB: { e } " )
405- return []
406-
407351 # Data Management Operations
408352 async def get_data_by_type (self , data_type : str ) -> List [BaseDataModel ]:
409353 """Retrieve all data of a specific type."""
@@ -470,13 +414,89 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
470414 teams = await self .query_items (query , parameters , UserCurrentTeam )
471415 return teams [0 ] if teams else None
472416
417+
418+
419+ async def delete_current_team (self , user_id : str ) -> bool :
420+ """Delete the current team for a user."""
421+ query = "SELECT c.id, c.session_id FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
422+
423+ params = [
424+ {"name" : "@user_id" , "value" : user_id },
425+ {"name" : "@data_type" , "value" : DataType .user_current_team },
426+ ]
427+ items = self .container .query_items (query = query , parameters = params )
428+ print ("Items to delete:" , items )
429+ if items :
430+ async for doc in items :
431+ try :
432+ await self .container .delete_item (doc ["id" ], partition_key = doc ["session_id" ])
433+ except Exception as e :
434+ self .logger .warning ("Failed deleting current team doc %s: %s" , doc .get ("id" ), e )
435+
436+ return True
437+
473438 async def set_current_team (self , current_team : UserCurrentTeam ) -> None :
474439 """Set the current team for a user."""
475440 await self ._ensure_initialized ()
476441 await self .add_item (current_team )
477442
478-
479443 async def update_current_team (self , current_team : UserCurrentTeam ) -> None :
480444 """Update the current team for a user."""
481445 await self ._ensure_initialized ()
482446 await self .update_item (current_team )
447+
448+ async def delete_plan_by_plan_id (self , plan_id : str ) -> bool :
449+ """Delete a plan by its ID."""
450+ query = "SELECT c.id, c.session_id FROM c WHERE c.id=@plan_id "
451+
452+ params = [
453+ {"name" : "@plan_id" , "value" : plan_id },
454+ ]
455+ items = self .container .query_items (query = query , parameters = params )
456+ print ("Items to delete planid:" , items )
457+ if items :
458+ async for doc in items :
459+ try :
460+ await self .container .delete_item (doc ["id" ], partition_key = doc ["session_id" ])
461+ except Exception as e :
462+ self .logger .warning ("Failed deleting current team doc %s: %s" , doc .get ("id" ), e )
463+
464+ return True
465+
466+ async def add_mplan (self , mplan : messages .MPlan ) -> None :
467+ """Add a team configuration to the database."""
468+ await self .add_item (mplan )
469+
470+ async def update_mplan (self , mplan : messages .MPlan ) -> None :
471+ """Update a team configuration in the database."""
472+ await self .update_item (mplan )
473+
474+
475+ async def get_mplan (self , plan_id : str ) -> Optional [messages .MPlan ]:
476+ """Retrieve a mplan configuration by mplan_id."""
477+ query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type"
478+ parameters = [
479+ {"name" : "@plan_id" , "value" : plan_id },
480+ {"name" : "@data_type" , "value" : DataType .m_plan },
481+ ]
482+ results = await self .query_items (query , parameters , messages .MPlan )
483+ return results [0 ] if results else None
484+
485+
486+ async def add_agent_message (self , message : AgentMessageData ) -> None :
487+ """Add an agent message to the database."""
488+ await self .add_item (message )
489+
490+ async def update_agent_message (self , message : AgentMessageData ) -> None :
491+ """Update an agent message in the database."""
492+ await self .update_item (message )
493+
494+ async def get_agent_messages (self , plan_id : str ) -> List [AgentMessageData ]:
495+ """Retrieve an agent message by message_id."""
496+ query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c._ts ASC"
497+ parameters = [
498+ {"name" : "@plan_id" , "value" : plan_id },
499+ {"name" : "@data_type" , "value" : DataType .m_plan_message },
500+ ]
501+
502+ return await self .query_items (query , parameters , AgentMessageData )
0 commit comments