Skip to content

Commit b5cc787

Browse files
committed
Add team-specific methods to CosmosMemoryContext
Introduces dedicated CRUD methods for TeamConfiguration objects in CosmosMemoryContext, including add, update, retrieve, and delete operations. Refactors JsonService to use these new methods for team configuration management, improving clarity and encapsulation. Adds a test script to verify the correct behavior of all team-specific methods.
1 parent 81a732e commit b5cc787

File tree

3 files changed

+316
-109
lines changed

3 files changed

+316
-109
lines changed

src/backend/context/cosmos_memory_kernel.py

Lines changed: 134 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717

1818
# Import the AppConfig instance
1919
from app_config import config
20-
from models.messages_kernel import BaseDataModel, Plan, Session, Step, AgentMessage
20+
from models.messages_kernel import (
21+
BaseDataModel,
22+
Plan,
23+
Session,
24+
Step,
25+
AgentMessage,
26+
TeamConfiguration,
27+
)
2128

2229

2330
# Add custom JSON encoder class for datetime objects
@@ -38,6 +45,7 @@ class CosmosMemoryContext(MemoryStoreBase):
3845
"plan": Plan,
3946
"step": Step,
4047
"agent_message": AgentMessage,
48+
"team_config": TeamConfiguration,
4149
# Messages are handled separately
4250
}
4351

@@ -168,65 +176,6 @@ async def get_item_by_id(
168176
logging.exception(f"Failed to retrieve item from Cosmos DB: {e}")
169177
return None
170178

171-
async def query_items_with_parameters(
172-
self, query: str, parameters: List[Dict[str, Any]], limit: int = 1000
173-
) -> List[Dict[str, Any]]:
174-
"""Query items from Cosmos DB with parameters and return raw dictionaries."""
175-
await self.ensure_initialized()
176-
177-
try:
178-
items = self._container.query_items(
179-
query=query,
180-
parameters=parameters,
181-
enable_cross_partition_query=True
182-
)
183-
result_list = []
184-
count = 0
185-
async for item in items:
186-
if count >= limit:
187-
break
188-
result_list.append(item)
189-
count += 1
190-
return result_list
191-
except Exception as e:
192-
logging.exception(f"Failed to query items from Cosmos DB: {e}")
193-
return []
194-
195-
async def get_async(self, key: str) -> Optional[Dict[str, Any]]:
196-
"""Get an item by its key/ID."""
197-
await self.ensure_initialized()
198-
199-
try:
200-
# Query by ID across all partitions since we don't know the partition key
201-
query = "SELECT * FROM c WHERE c.id=@id"
202-
parameters = [{"name": "@id", "value": key}]
203-
204-
items = self._container.query_items(
205-
query=query,
206-
parameters=parameters,
207-
enable_cross_partition_query=True
208-
)
209-
210-
async for item in items:
211-
return item
212-
return None
213-
except Exception as e:
214-
logging.exception(f"Failed to get item from Cosmos DB: {e}")
215-
return None
216-
217-
async def delete_async(self, key: str) -> None:
218-
"""Delete an item by its key/ID."""
219-
await self.ensure_initialized()
220-
221-
try:
222-
# First get the item to find its partition key
223-
item = await self.get_async(key)
224-
if item:
225-
partition_key = item.get("session_id", item.get("user_id", key))
226-
await self._container.delete_item(item=key, partition_key=partition_key)
227-
except Exception as e:
228-
logging.exception(f"Failed to delete item from Cosmos DB: {e}")
229-
230179
async def query_items(
231180
self,
232181
query: str,
@@ -400,6 +349,131 @@ async def get_agent_messages_by_session(
400349
messages = await self.query_items(query, parameters, AgentMessage)
401350
return messages
402351

352+
async def add_team(self, team: TeamConfiguration) -> None:
353+
"""Add a team configuration to Cosmos DB.
354+
355+
Args:
356+
team: The TeamConfiguration to add
357+
"""
358+
await self.add_item(team)
359+
360+
async def update_team(self, team: TeamConfiguration) -> None:
361+
"""Update an existing team configuration in Cosmos DB.
362+
363+
Args:
364+
team: The TeamConfiguration to update
365+
"""
366+
await self.update_item(team)
367+
368+
async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
369+
"""Retrieve a specific team configuration by team_id.
370+
371+
Args:
372+
team_id: The team_id of the team configuration to retrieve
373+
374+
Returns:
375+
TeamConfiguration object or None if not found
376+
"""
377+
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type"
378+
parameters = [
379+
{"name": "@team_id", "value": team_id},
380+
{"name": "@data_type", "value": "team_config"},
381+
]
382+
teams = await self.query_items(query, parameters, TeamConfiguration)
383+
return teams[0] if teams else None
384+
385+
async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
386+
"""Retrieve a specific team configuration by its document id.
387+
388+
Args:
389+
id: The document id of the team configuration to retrieve
390+
391+
Returns:
392+
TeamConfiguration object or None if not found
393+
"""
394+
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
395+
parameters = [
396+
{"name": "@id", "value": id},
397+
{"name": "@data_type", "value": "team_config"},
398+
]
399+
teams = await self.query_items(query, parameters, TeamConfiguration)
400+
return teams[0] if teams else None
401+
402+
async def get_all_teams_by_user(self, user_id: str) -> List[TeamConfiguration]:
403+
"""Retrieve all team configurations for a specific user.
404+
405+
Args:
406+
user_id: The user_id to get team configurations for
407+
408+
Returns:
409+
List of TeamConfiguration objects
410+
"""
411+
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type ORDER BY c.created DESC"
412+
parameters = [
413+
{"name": "@user_id", "value": user_id},
414+
{"name": "@data_type", "value": "team_config"},
415+
]
416+
teams = await self.query_items(query, parameters, TeamConfiguration)
417+
return teams
418+
419+
async def delete_team(self, team_id: str) -> bool:
420+
"""Delete a team configuration by team_id.
421+
422+
Args:
423+
team_id: The team_id of the team configuration to delete
424+
425+
Returns:
426+
True if team was found and deleted, False otherwise
427+
"""
428+
await self.ensure_initialized()
429+
430+
try:
431+
# First find the team to get its document id and partition key
432+
team = await self.get_team(team_id)
433+
if team:
434+
# Use the session_id as partition key, or fall back to user_id if no session_id
435+
partition_key = (
436+
team.session_id
437+
if hasattr(team, "session_id") and team.session_id
438+
else team.user_id
439+
)
440+
await self._container.delete_item(
441+
item=team.id, partition_key=partition_key
442+
)
443+
return True
444+
return False
445+
except Exception as e:
446+
logging.exception(f"Failed to delete team from Cosmos DB: {e}")
447+
return False
448+
449+
async def delete_team_by_id(self, id: str) -> bool:
450+
"""Delete a team configuration by its document id.
451+
452+
Args:
453+
id: The document id of the team configuration to delete
454+
455+
Returns:
456+
True if team was found and deleted, False otherwise
457+
"""
458+
await self.ensure_initialized()
459+
460+
try:
461+
# First find the team to get its partition key
462+
team = await self.get_team_by_id(id)
463+
if team:
464+
# Use the session_id as partition key, or fall back to user_id if no session_id
465+
partition_key = (
466+
team.session_id
467+
if hasattr(team, "session_id") and team.session_id
468+
else team.user_id
469+
)
470+
await self._container.delete_item(item=id, partition_key=partition_key)
471+
return True
472+
return False
473+
except Exception as e:
474+
logging.exception(f"Failed to delete team from Cosmos DB: {e}")
475+
return False
476+
403477
async def add_message(self, message: ChatMessageContent) -> None:
404478
"""Add a message to the memory and save to Cosmos DB."""
405479
await self.ensure_initialized()

src/backend/services/json_service.py

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,8 @@ async def save_team_configuration(self, team_config: TeamConfiguration) -> str:
149149
The unique ID of the saved configuration
150150
"""
151151
try:
152-
# Convert to dictionary and add data_type for proper querying
153-
config_dict = team_config.model_dump()
154-
config_dict["data_type"] = "team_config"
155-
156-
# Use the cosmos memory context to save the team configuration
157-
await self.memory_context.upsert_async(
158-
collection_name=team_config.id,
159-
record=config_dict
160-
)
152+
# Use the specific add_team method from cosmos memory context
153+
await self.memory_context.add_team(team_config)
161154

162155
self.logger.info(
163156
"Successfully saved team configuration with ID: %s", team_config.id
@@ -182,22 +175,22 @@ async def get_team_configuration(
182175
TeamConfiguration object or None if not found
183176
"""
184177
try:
185-
# Get the specific configuration using cosmos memory context
186-
config_dict = await self.memory_context.get_async(config_id)
187-
188-
if config_dict is None:
178+
# Get the specific configuration using the team-specific method
179+
team_config = await self.memory_context.get_team_by_id(config_id)
180+
181+
if team_config is None:
189182
return None
190-
183+
191184
# Verify the configuration belongs to the user
192-
if config_dict.get("user_id") != user_id:
185+
if team_config.user_id != user_id:
193186
self.logger.warning(
194187
"Access denied: config %s does not belong to user %s",
195188
config_id,
196189
user_id,
197190
)
198191
return None
199192

200-
return TeamConfiguration.model_validate(config_dict)
193+
return team_config
201194

202195
except (KeyError, TypeError, ValueError) as e:
203196
self.logger.error("Error retrieving team configuration: %s", str(e))
@@ -216,28 +209,8 @@ async def get_all_team_configurations(
216209
List of TeamConfiguration objects
217210
"""
218211
try:
219-
# Query configurations using SQL with parameters through cosmos memory context
220-
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
221-
parameters = [
222-
{"name": "@user_id", "value": user_id},
223-
{"name": "@data_type", "value": "team_config"},
224-
]
225-
226-
configs = await self.memory_context.query_items_with_parameters(
227-
query, parameters, limit=1000
228-
)
229-
230-
team_configs = []
231-
for config_dict in configs:
232-
try:
233-
team_config = TeamConfiguration.model_validate(config_dict)
234-
team_configs.append(team_config)
235-
except (ValueError, TypeError) as e:
236-
self.logger.warning(
237-
"Failed to parse team configuration: %s", str(e)
238-
)
239-
continue
240-
212+
# Use the specific get_all_teams_by_user method
213+
team_configs = await self.memory_context.get_all_teams_by_user(user_id)
241214
return team_configs
242215

243216
except (KeyError, TypeError, ValueError) as e:
@@ -257,27 +230,32 @@ async def delete_team_configuration(self, config_id: str, user_id: str) -> bool:
257230
"""
258231
try:
259232
# First, verify the configuration exists and belongs to the user
260-
config_dict = await self.memory_context.get_async(config_id)
261-
262-
if config_dict is None:
233+
team_config = await self.memory_context.get_team_by_id(config_id)
234+
235+
if team_config is None:
263236
self.logger.warning(
264237
"Team configuration not found for deletion: %s", config_id
265238
)
266239
return False
267-
240+
268241
# Verify the configuration belongs to the user
269-
if config_dict.get("user_id") != user_id:
242+
if team_config.user_id != user_id:
270243
self.logger.warning(
271-
"Access denied: cannot delete config %s for user %s",
272-
config_id, user_id
244+
"Access denied: cannot delete config %s for user %s",
245+
config_id,
246+
user_id,
273247
)
274248
return False
275249

276-
# Delete the configuration using cosmos memory context
277-
await self.memory_context.delete_async(config_id)
250+
# Delete the configuration using the specific delete_team_by_id method
251+
success = await self.memory_context.delete_team_by_id(config_id)
252+
253+
if success:
254+
self.logger.info(
255+
"Successfully deleted team configuration: %s", config_id
256+
)
278257

279-
self.logger.info("Successfully deleted team configuration: %s", config_id)
280-
return True
258+
return success
281259

282260
except (KeyError, TypeError, ValueError) as e:
283261
self.logger.error("Error deleting team configuration: %s", str(e))

0 commit comments

Comments
 (0)