1111from azure .cosmos .aio import CosmosClient
1212from azure .cosmos .aio ._database import DatabaseProxy
1313from azure .cosmos .exceptions import CosmosResourceExistsError
14- from pytest import Session
14+ import v3 . models . messages as messages
1515
1616from common .models .messages_kernel import (
1717 AgentMessage ,
2424from .database_base import DatabaseBase
2525from ..models .messages_kernel import (
2626 BaseDataModel ,
27- Session ,
2827 Plan ,
2928 Step ,
3029 AgentMessage ,
@@ -38,7 +37,6 @@ class CosmosDBClient(DatabaseBase):
3837 """CosmosDB implementation of the database interface."""
3938
4039 MODEL_CLASS_MAPPING = {
41- DataType .session : Session ,
4240 DataType .plan : Plan ,
4341 DataType .step : Step ,
4442 DataType .agent_message : AgentMessage ,
@@ -200,17 +198,6 @@ async def update_plan(self, plan: Plan) -> None:
200198 """Update a plan in CosmosDB."""
201199 await self .update_item (plan )
202200
203- async def get_plan_by_session (self , session_id : str ) -> Optional [Plan ]:
204- """Retrieve a plan by session_id."""
205- query = (
206- "SELECT * FROM c WHERE c.session_id=@session_id AND c.data_type=@data_type"
207- )
208- parameters = [
209- {"name" : "@session_id" , "value" : session_id },
210- {"name" : "@data_type" , "value" : DataType .plan },
211- ]
212- results = await self .query_items (query , parameters , Plan )
213- return results [0 ] if results else None
214201
215202 async def get_plan_by_plan_id (self , plan_id : str ) -> Optional [Plan ]:
216203 """Retrieve a plan by plan_id."""
@@ -360,27 +347,6 @@ async def delete_team(self, team_id: str) -> bool:
360347 logging .exception (f"Failed to delete team from Cosmos DB: { e } " )
361348 return False
362349
363- async def get_data_by_type_and_session_id (
364- self , data_type : str , session_id : str
365- ) -> List [BaseDataModel ]:
366- """Query the Cosmos DB for documents with the matching data_type, session_id and user_id."""
367- await self ._ensure_initialized ()
368- if self .container is None :
369- return []
370-
371- model_class = self .MODEL_CLASS_MAPPING .get (data_type , BaseDataModel )
372- try :
373- query = "SELECT * FROM c WHERE c.session_id=@session_id AND c.user_id=@user_id AND c.data_type=@data_type ORDER BY c._ts ASC"
374- parameters = [
375- {"name" : "@session_id" , "value" : session_id },
376- {"name" : "@data_type" , "value" : data_type },
377- {"name" : "@user_id" , "value" : self .user_id },
378- ]
379- return await self .query_items (query , parameters , model_class )
380- except Exception as e :
381- logging .exception (f"Failed to query data by type from Cosmos DB: { e } " )
382- return []
383-
384350 # Data Management Operations
385351 async def get_data_by_type (self , data_type : str ) -> List [BaseDataModel ]:
386352 """Retrieve all data of a specific type."""
@@ -477,3 +443,40 @@ async def update_current_team(self, current_team: UserCurrentTeam) -> None:
477443 """Update the current team for a user."""
478444 await self ._ensure_initialized ()
479445 await self .update_item (current_team )
446+
447+ async def delete_plan_by_plan_id (self , plan_id : str ) -> bool :
448+ """Delete a plan by its ID."""
449+ query = "SELECT c.id, c.session_id FROM c WHERE c.id=@plan_id "
450+
451+ params = [
452+ {"name" : "@plan_id" , "value" : plan_id },
453+ ]
454+ items = self .container .query_items (query = query , parameters = params )
455+ print ("Items to delete planid:" , items )
456+ if items :
457+ async for doc in items :
458+ try :
459+ await self .container .delete_item (doc ["id" ], partition_key = doc ["session_id" ])
460+ except Exception as e :
461+ self .logger .warning ("Failed deleting current team doc %s: %s" , doc .get ("id" ), e )
462+
463+ return True
464+
465+ async def add_mplan (self , mplan : messages .MPlan ) -> None :
466+ """Add a team configuration to the database."""
467+ await self .add_item (mplan )
468+
469+ async def update_mplan (self , mplan : messages .MPlan ) -> None :
470+ """Update a team configuration in the database."""
471+ await self .update_item (mplan )
472+
473+
474+ async def get_mplan (self , plan_id : str ) -> Optional [messages .MPlan ]:
475+ """Retrieve a mplan configuration by mplan_id."""
476+ query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type"
477+ parameters = [
478+ {"name" : "@plan_id" , "value" : plan_id },
479+ {"name" : "@data_type" , "value" : DataType .m_plan },
480+ ]
481+ results = await self .query_items (query , parameters , messages .MPlan )
482+ return results [0 ] if results else None
0 commit comments