Skip to content

Commit 2d63bb2

Browse files
committed
Refactor session and team management in backend
Removed session-related methods and fields from database base and CosmosDB client, and updated data models to eliminate redundant session_id fields. Added and refactored current team deletion logic, and improved team selection handling to return the selected team. Updated API and service layers to use new team management methods and fixed related bugs.
1 parent d34ea29 commit 2d63bb2

File tree

5 files changed

+45
-68
lines changed

5 files changed

+45
-68
lines changed

src/backend/common/database/cosmosdb.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -190,29 +190,6 @@ async def delete_item(self, item_id: str, partition_key: str) -> None:
190190
self.logger.error("Failed to delete item from CosmosDB: %s", str(e))
191191
raise
192192

193-
# Session Operations
194-
async def add_session(self, session: Session) -> None:
195-
"""Add a session to CosmosDB."""
196-
await self.add_item(session)
197-
198-
async def get_session(self, session_id: str) -> Optional[Session]:
199-
"""Retrieve a session by session_id."""
200-
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
201-
parameters = [
202-
{"name": "@id", "value": session_id},
203-
{"name": "@data_type", "value": DataType.session},
204-
]
205-
results = await self.query_items(query, parameters, Session)
206-
return results[0] if results else None
207-
208-
async def get_all_sessions(self) -> List[Session]:
209-
"""Retrieve all sessions for the user."""
210-
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
211-
parameters = [
212-
{"name": "@user_id", "value": self.user_id},
213-
{"name": "@data_type", "value": DataType.session},
214-
]
215-
return await self.query_items(query, parameters, Session)
216193

217194
# Plan Operations
218195
async def add_plan(self, plan: Plan) -> None:
@@ -470,12 +447,32 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
470447
teams = await self.query_items(query, parameters, UserCurrentTeam)
471448
return teams[0] if teams else None
472449

450+
451+
452+
async def delete_current_team(self, user_id: str) -> bool:
453+
"""Delete the current team for a user."""
454+
query = "SELECT c.id, c.session_id FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
455+
456+
params = [
457+
{"name": "@user_id", "value": user_id},
458+
{"name": "@data_type", "value": DataType.user_current_team},
459+
]
460+
items = self.container.query_items(query=query, parameters=params)
461+
print("Items to delete:", items)
462+
if items:
463+
async for doc in items:
464+
try:
465+
await self.container.delete_item(doc["id"], partition_key=doc["session_id"])
466+
except Exception as e:
467+
self.logger.warning("Failed deleting current team doc %s: %s", doc.get("id"), e)
468+
469+
return True
470+
473471
async def set_current_team(self, current_team: UserCurrentTeam) -> None:
474472
"""Set the current team for a user."""
475473
await self._ensure_initialized()
476474
await self.add_item(current_team)
477475

478-
479476
async def update_current_team(self, current_team: UserCurrentTeam) -> None:
480477
"""Update the current team for a user."""
481478
await self._ensure_initialized()

src/backend/common/database/database_base.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from ..models.messages_kernel import (
77
BaseDataModel,
8-
Session,
98
Plan,
109
Step,
1110
TeamConfiguration,
@@ -59,21 +58,6 @@ async def delete_item(self, item_id: str, partition_key: str) -> None:
5958
"""Delete an item from the database."""
6059
pass
6160

62-
# Session Operations
63-
@abstractmethod
64-
async def add_session(self, session: Session) -> None:
65-
"""Add a session to the database."""
66-
pass
67-
68-
@abstractmethod
69-
async def get_session(self, session_id: str) -> Optional[Session]:
70-
"""Retrieve a session by session_id."""
71-
pass
72-
73-
@abstractmethod
74-
async def get_all_sessions(self) -> List[Session]:
75-
"""Retrieve all sessions for the user."""
76-
pass
7761

7862
# Plan Operations
7963
@abstractmethod
@@ -86,11 +70,6 @@ async def update_plan(self, plan: Plan) -> None:
8670
"""Update a plan in the database."""
8771
pass
8872

89-
@abstractmethod
90-
async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
91-
"""Retrieve a plan by session_id."""
92-
pass
93-
9473
@abstractmethod
9574
async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
9675
"""Retrieve a plan by plan_id."""
@@ -207,6 +186,11 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
207186
"""Retrieve the current team for a user."""
208187
pass
209188

189+
@abstractmethod
190+
async def delete_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
191+
"""Retrieve the current team for a user."""
192+
pass
193+
210194
@abstractmethod
211195
async def set_current_team(self, current_team: UserCurrentTeam) -> None:
212196
pass

src/backend/common/models/messages_kernel.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class BaseDataModel(KernelBaseModel):
7878
"""Base data model with common fields."""
7979

8080
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
81+
session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
8182
timestamp: Optional[datetime] = Field(
8283
default_factory=lambda: datetime.now(timezone.utc)
8384
)
@@ -87,8 +88,6 @@ class AgentMessage(BaseDataModel):
8788
"""Base class for messages sent between agents."""
8889

8990
data_type: Literal[DataType.agent_message] = Field(DataType.agent_message, Literal=True)
90-
session_id: str
91-
user_id: str
9291
plan_id: str
9392
content: str
9493
source: str
@@ -118,7 +117,6 @@ class Plan(BaseDataModel):
118117

119118
data_type: Literal[DataType.plan] = Field(DataType.plan, Literal=True)
120119
plan_id: str
121-
session_id: str
122120
user_id: str
123121
initial_goal: str
124122
overall_status: PlanStatus = PlanStatus.in_progress
@@ -135,7 +133,6 @@ class Step(BaseDataModel):
135133

136134
data_type: Literal[DataType.step] = Field(DataType.step, Literal=True)
137135
plan_id: str
138-
session_id: str # Partition key
139136
user_id: str
140137
action: str
141138
agent: AgentType
@@ -146,7 +143,7 @@ class Step(BaseDataModel):
146143
updated_action: Optional[str] = None
147144

148145

149-
class TeamSelectionRequest(KernelBaseModel):
146+
class TeamSelectionRequest(BaseDataModel):
150147
"""Request model for team selection."""
151148

152149
team_id: str

src/backend/v3/api/router.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,12 @@ async def init_team(
117117
team_service = TeamService(memory_store)
118118
user_current_team = await memory_store.get_current_team(user_id=user_id)
119119
if not user_current_team:
120-
await team_service.handle_team_selection(
120+
print("User has no current team, setting to default:", init_team_id)
121+
user_current_team = await team_service.handle_team_selection(
121122
user_id=user_id, team_id=init_team_id
122123
)
124+
if user_current_team:
125+
init_team_id = user_current_team.team_id
123126
else:
124127
init_team_id = user_current_team.team_id
125128
# Verify the team exists and user has access to it
@@ -896,7 +899,7 @@ async def select_team(selection: TeamSelectionRequest, request: Request):
896899
team_configuration = await team_service.get_team_configuration(
897900
selection.team_id, user_id
898901
)
899-
if team_config is None:
902+
if team_configuration is None: # ensure that id is valid
900903
raise HTTPException(
901904
status_code=404,
902905
detail=f"Team configuration '{selection.team_id}' not found or access denied",

src/backend/v3/common/services/team_service.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -235,15 +235,15 @@ async def delete_user_current_team(self, user_id: str) -> bool:
235235
True if successful, False otherwise
236236
"""
237237
try:
238-
await self.memory_context.delete_user_current_team(user_id)
238+
await self.memory_context.delete_current_team(user_id)
239239
self.logger.info("Successfully deleted current team for user %s", user_id)
240240
return True
241241

242242
except Exception as e:
243243
self.logger.error("Error deleting current team: %s", str(e))
244244
return False
245245

246-
async def handle_team_selection(self, user_id: str, team_id: str) -> bool:
246+
async def handle_team_selection(self, user_id: str, team_id: str) -> UserCurrentTeam:
247247
"""
248248
Set a default team for a user.
249249
@@ -254,25 +254,21 @@ async def handle_team_selection(self, user_id: str, team_id: str) -> bool:
254254
Returns:
255255
True if successful, False otherwise
256256
"""
257+
print("Handling team selection for user:", user_id, "team:", team_id)
257258
try:
258-
current_team = await self.memory_context.get_current_team(user_id)
259-
260-
if current_team is None:
261-
current_team = UserCurrentTeam(user_id=user_id, team_id=team_id)
262-
await self.memory_context.set_current_team(current_team)
263-
return True
264-
else:
265-
current_team.team_id = team_id
266-
await self.memory_context.update_current_team(current_team)
267-
return True
259+
await self.memory_context.delete_current_team(user_id)
260+
current_team = UserCurrentTeam(
261+
user_id=user_id,
262+
team_id=team_id,
263+
)
264+
await self.memory_context.set_current_team(current_team)
265+
return current_team
268266

269267
except Exception as e:
270268
self.logger.error("Error setting default team: %s", str(e))
271-
return False
269+
return None
272270

273-
async def get_all_team_configurations(
274-
self
275-
) -> List[TeamConfiguration]:
271+
async def get_all_team_configurations(self) -> List[TeamConfiguration]:
276272
"""
277273
Retrieve all team configurations for a user.
278274

0 commit comments

Comments
 (0)