Skip to content

Commit bcf5c22

Browse files
Merge pull request #480 from microsoft/macae-v3-fr-dev-92
Macae v3 fr dev 92
2 parents d013efd + 5007810 commit bcf5c22

File tree

22 files changed

+1237
-646
lines changed

22 files changed

+1237
-646
lines changed

src/backend/common/database/cosmosdb.py

Lines changed: 83 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from azure.cosmos.aio import CosmosClient
1212
from azure.cosmos.aio._database import DatabaseProxy
1313
from azure.cosmos.exceptions import CosmosResourceExistsError
14-
from pytest import Session
14+
import v3.models.messages as messages
1515

1616
from common.models.messages_kernel import (
1717
AgentMessage,
@@ -23,8 +23,8 @@
2323

2424
from .database_base import DatabaseBase
2525
from ..models.messages_kernel import (
26+
AgentMessageData,
2627
BaseDataModel,
27-
Session,
2828
Plan,
2929
Step,
3030
AgentMessage,
@@ -38,7 +38,6 @@ class CosmosDBClient(DatabaseBase):
3838
"""CosmosDB implementation of the database interface."""
3939

4040
MODEL_CLASS_MAPPING = {
41-
DataType.session: Session,
4241
DataType.plan: Plan,
4342
DataType.step: Step,
4443
DataType.agent_message: AgentMessage,
@@ -190,29 +189,6 @@ async def delete_item(self, item_id: str, partition_key: str) -> None:
190189
self.logger.error("Failed to delete item from CosmosDB: %s", str(e))
191190
raise
192191

193-
# Session Operations
194-
async def add_session(self, session: Session) -> None:
195-
"""Add a session to CosmosDB."""
196-
await self.add_item(session)
197-
198-
async def get_session(self, session_id: str) -> Optional[Session]:
199-
"""Retrieve a session by session_id."""
200-
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
201-
parameters = [
202-
{"name": "@id", "value": session_id},
203-
{"name": "@data_type", "value": DataType.session},
204-
]
205-
results = await self.query_items(query, parameters, Session)
206-
return results[0] if results else None
207-
208-
async def get_all_sessions(self) -> List[Session]:
209-
"""Retrieve all sessions for the user."""
210-
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
211-
parameters = [
212-
{"name": "@user_id", "value": self.user_id},
213-
{"name": "@data_type", "value": DataType.session},
214-
]
215-
return await self.query_items(query, parameters, Session)
216192

217193
# Plan Operations
218194
async def add_plan(self, plan: Plan) -> None:
@@ -223,17 +199,6 @@ async def update_plan(self, plan: Plan) -> None:
223199
"""Update a plan in CosmosDB."""
224200
await self.update_item(plan)
225201

226-
async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
227-
"""Retrieve a plan by session_id."""
228-
query = (
229-
"SELECT * FROM c WHERE c.session_id=@session_id AND c.data_type=@data_type"
230-
)
231-
parameters = [
232-
{"name": "@session_id", "value": session_id},
233-
{"name": "@data_type", "value": DataType.plan},
234-
]
235-
results = await self.query_items(query, parameters, Plan)
236-
return results[0] if results else None
237202

238203
async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
239204
"""Retrieve a plan by plan_id."""
@@ -272,7 +237,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
272237

273238
async def get_all_plans_by_team_id_status(self, team_id: str, status: str) -> List[Plan]:
274239
"""Retrieve all plans for a specific team."""
275-
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status"
240+
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status ORDER BY c._ts DESC"
276241
parameters = [
277242
{"name": "@user_id", "value": self.user_id},
278243
{"name": "@team_id", "value": team_id},
@@ -328,7 +293,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
328293
teams = await self.query_items(query, parameters, TeamConfiguration)
329294
return teams[0] if teams else None
330295

331-
async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
296+
async def get_team_by_id(self, team_id: str) -> Optional[TeamConfiguration]:
332297
"""Retrieve a specific team configuration by its document id.
333298
334299
Args:
@@ -337,9 +302,9 @@ async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
337302
Returns:
338303
TeamConfiguration object or None if not found
339304
"""
340-
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
305+
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type"
341306
parameters = [
342-
{"name": "@id", "value": id},
307+
{"name": "@team_id", "value": team_id},
343308
{"name": "@data_type", "value": DataType.team_config},
344309
]
345310
teams = await self.query_items(query, parameters, TeamConfiguration)
@@ -383,27 +348,6 @@ async def delete_team(self, team_id: str) -> bool:
383348
logging.exception(f"Failed to delete team from Cosmos DB: {e}")
384349
return False
385350

386-
async def get_data_by_type_and_session_id(
387-
self, data_type: str, session_id: str
388-
) -> List[BaseDataModel]:
389-
"""Query the Cosmos DB for documents with the matching data_type, session_id and user_id."""
390-
await self._ensure_initialized()
391-
if self.container is None:
392-
return []
393-
394-
model_class = self.MODEL_CLASS_MAPPING.get(data_type, BaseDataModel)
395-
try:
396-
query = "SELECT * FROM c WHERE c.session_id=@session_id AND c.user_id=@user_id AND c.data_type=@data_type ORDER BY c._ts ASC"
397-
parameters = [
398-
{"name": "@session_id", "value": session_id},
399-
{"name": "@data_type", "value": data_type},
400-
{"name": "@user_id", "value": self.user_id},
401-
]
402-
return await self.query_items(query, parameters, model_class)
403-
except Exception as e:
404-
logging.exception(f"Failed to query data by type from Cosmos DB: {e}")
405-
return []
406-
407351
# Data Management Operations
408352
async def get_data_by_type(self, data_type: str) -> List[BaseDataModel]:
409353
"""Retrieve all data of a specific type."""
@@ -470,13 +414,89 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
470414
teams = await self.query_items(query, parameters, UserCurrentTeam)
471415
return teams[0] if teams else None
472416

417+
418+
419+
async def delete_current_team(self, user_id: str) -> bool:
420+
"""Delete the current team for a user."""
421+
query = "SELECT c.id, c.session_id FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
422+
423+
params = [
424+
{"name": "@user_id", "value": user_id},
425+
{"name": "@data_type", "value": DataType.user_current_team},
426+
]
427+
items = self.container.query_items(query=query, parameters=params)
428+
print("Items to delete:", items)
429+
if items:
430+
async for doc in items:
431+
try:
432+
await self.container.delete_item(doc["id"], partition_key=doc["session_id"])
433+
except Exception as e:
434+
self.logger.warning("Failed deleting current team doc %s: %s", doc.get("id"), e)
435+
436+
return True
437+
473438
async def set_current_team(self, current_team: UserCurrentTeam) -> None:
474439
"""Set the current team for a user."""
475440
await self._ensure_initialized()
476441
await self.add_item(current_team)
477442

478-
479443
async def update_current_team(self, current_team: UserCurrentTeam) -> None:
480444
"""Update the current team for a user."""
481445
await self._ensure_initialized()
482446
await self.update_item(current_team)
447+
448+
async def delete_plan_by_plan_id(self, plan_id: str) -> bool:
449+
"""Delete a plan by its ID."""
450+
query = "SELECT c.id, c.session_id FROM c WHERE c.id=@plan_id "
451+
452+
params = [
453+
{"name": "@plan_id", "value": plan_id},
454+
]
455+
items = self.container.query_items(query=query, parameters=params)
456+
print("Items to delete planid:", items)
457+
if items:
458+
async for doc in items:
459+
try:
460+
await self.container.delete_item(doc["id"], partition_key=doc["session_id"])
461+
except Exception as e:
462+
self.logger.warning("Failed deleting current team doc %s: %s", doc.get("id"), e)
463+
464+
return True
465+
466+
async def add_mplan(self, mplan: messages.MPlan) -> None:
467+
"""Add a team configuration to the database."""
468+
await self.add_item(mplan)
469+
470+
async def update_mplan(self, mplan: messages.MPlan) -> None:
471+
"""Update a team configuration in the database."""
472+
await self.update_item(mplan)
473+
474+
475+
async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]:
476+
"""Retrieve a mplan configuration by mplan_id."""
477+
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type"
478+
parameters = [
479+
{"name": "@plan_id", "value": plan_id},
480+
{"name": "@data_type", "value": DataType.m_plan},
481+
]
482+
results = await self.query_items(query, parameters, messages.MPlan)
483+
return results[0] if results else None
484+
485+
486+
async def add_agent_message(self, message: AgentMessageData) -> None:
487+
"""Add an agent message to the database."""
488+
await self.add_item(message)
489+
490+
async def update_agent_message(self, message: AgentMessageData) -> None:
491+
"""Update an agent message in the database."""
492+
await self.update_item(message)
493+
494+
async def get_agent_messages(self, plan_id: str) -> List[AgentMessageData]:
495+
"""Retrieve an agent message by message_id."""
496+
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c._ts ASC"
497+
parameters = [
498+
{"name": "@plan_id", "value": plan_id},
499+
{"name": "@data_type", "value": DataType.m_plan_message},
500+
]
501+
502+
return await self.query_items(query, parameters, AgentMessageData)

src/backend/common/database/database_base.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from abc import ABC, abstractmethod
44
from typing import Any, Dict, List, Optional, Type
5-
5+
import v3.models.messages as messages
66
from ..models.messages_kernel import (
7+
AgentMessageData,
78
BaseDataModel,
8-
Session,
99
Plan,
1010
Step,
1111
TeamConfiguration,
@@ -59,21 +59,6 @@ async def delete_item(self, item_id: str, partition_key: str) -> None:
5959
"""Delete an item from the database."""
6060
pass
6161

62-
# Session Operations
63-
@abstractmethod
64-
async def add_session(self, session: Session) -> None:
65-
"""Add a session to the database."""
66-
pass
67-
68-
@abstractmethod
69-
async def get_session(self, session_id: str) -> Optional[Session]:
70-
"""Retrieve a session by session_id."""
71-
pass
72-
73-
@abstractmethod
74-
async def get_all_sessions(self) -> List[Session]:
75-
"""Retrieve all sessions for the user."""
76-
pass
7762

7863
# Plan Operations
7964
@abstractmethod
@@ -86,11 +71,6 @@ async def update_plan(self, plan: Plan) -> None:
8671
"""Update a plan in the database."""
8772
pass
8873

89-
@abstractmethod
90-
async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
91-
"""Retrieve a plan by session_id."""
92-
pass
93-
9474
@abstractmethod
9575
async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
9676
"""Retrieve a plan by plan_id."""
@@ -118,11 +98,7 @@ async def get_all_plans_by_team_id_status(
11898
"""Retrieve all plans for a specific team."""
11999
pass
120100

121-
@abstractmethod
122-
async def get_data_by_type_and_session_id(
123-
self, data_type: str, session_id: str
124-
) -> List[BaseDataModel]:
125-
pass
101+
126102

127103
# Step Operations
128104
@abstractmethod
@@ -162,7 +138,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
162138
pass
163139

164140
@abstractmethod
165-
async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
141+
async def get_team_by_id(self, team_id: str) -> Optional[TeamConfiguration]:
166142
"""Retrieve a team configuration by internal id."""
167143
pass
168144

@@ -207,6 +183,11 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
207183
"""Retrieve the current team for a user."""
208184
pass
209185

186+
@abstractmethod
187+
async def delete_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
188+
"""Retrieve the current team for a user."""
189+
pass
190+
210191
@abstractmethod
211192
async def set_current_team(self, current_team: UserCurrentTeam) -> None:
212193
pass
@@ -215,3 +196,37 @@ async def set_current_team(self, current_team: UserCurrentTeam) -> None:
215196
async def update_current_team(self, current_team: UserCurrentTeam) -> None:
216197
"""Update the current team for a user."""
217198
pass
199+
200+
@abstractmethod
201+
async def delete_plan_by_plan_id(self, plan_id: str) -> bool:
202+
"""Retrieve the current team for a user."""
203+
pass
204+
205+
@abstractmethod
206+
async def add_mplan(self, mplan: messages.MPlan) -> None:
207+
"""Add a team configuration to the database."""
208+
pass
209+
210+
@abstractmethod
211+
async def update_mplan(self, mplan: messages.MPlan) -> None:
212+
"""Update a team configuration in the database."""
213+
pass
214+
215+
@abstractmethod
216+
async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]:
217+
"""Retrieve a mplan configuration by plan_id."""
218+
pass
219+
220+
@abstractmethod
221+
async def add_agent_message(self, message: AgentMessageData) -> None:
222+
pass
223+
224+
@abstractmethod
225+
async def update_agent_message(self, message: AgentMessageData) -> None:
226+
"""Update an agent message in the database."""
227+
pass
228+
229+
@abstractmethod
230+
async def get_agent_messages(self, plan_id: str) -> Optional[AgentMessageData]:
231+
"""Retrieve an agent message by message_id."""
232+
pass

0 commit comments

Comments
 (0)