Skip to content

Commit c8e0405

Browse files
committed
Refactor data_type usage and add PlanService
Replaces string literals with DataType enum for data_type fields and queries in models and CosmosDB client for consistency and type safety. Adds PlanService to encapsulate plan approval logic and updates router to use this service, improving separation of concerns and error handling.
1 parent 86f6a8f commit c8e0405

File tree

4 files changed

+86
-35
lines changed

4 files changed

+86
-35
lines changed

src/backend/common/database/cosmosdb.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ class CosmosDBClient(DatabaseBase):
3838
"""CosmosDB implementation of the database interface."""
3939

4040
MODEL_CLASS_MAPPING = {
41-
"session": Session,
42-
"plan": Plan,
43-
"step": Step,
44-
"agent_message": AgentMessage,
45-
"team_config": TeamConfiguration,
46-
"user_current_team": UserCurrentTeam,
41+
DataType.session: Session,
42+
DataType.plan: Plan,
43+
DataType.step: Step,
44+
DataType.agent_message: AgentMessage,
45+
DataType.team_config: TeamConfiguration,
46+
DataType.user_current_team: UserCurrentTeam,
4747
}
4848

4949
def __init__(
@@ -200,7 +200,7 @@ async def get_session(self, session_id: str) -> Optional[Session]:
200200
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
201201
parameters = [
202202
{"name": "@id", "value": session_id},
203-
{"name": "@data_type", "value": "session"},
203+
{"name": "@data_type", "value": DataType.session},
204204
]
205205
results = await self.query_items(query, parameters, Session)
206206
return results[0] if results else None
@@ -210,7 +210,7 @@ async def get_all_sessions(self) -> List[Session]:
210210
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
211211
parameters = [
212212
{"name": "@user_id", "value": self.user_id},
213-
{"name": "@data_type", "value": "session"},
213+
{"name": "@data_type", "value": DataType.session},
214214
]
215215
return await self.query_items(query, parameters, Session)
216216

@@ -230,7 +230,7 @@ async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
230230
)
231231
parameters = [
232232
{"name": "@session_id", "value": session_id},
233-
{"name": "@data_type", "value": "plan"},
233+
{"name": "@data_type", "value": DataType.plan},
234234
]
235235
results = await self.query_items(query, parameters, Plan)
236236
return results[0] if results else None
@@ -240,7 +240,7 @@ async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
240240
query = "SELECT * FROM c WHERE c.id=@plan_id AND c.data_type=@data_type"
241241
parameters = [
242242
{"name": "@plan_id", "value": plan_id},
243-
{"name": "@data_type", "value": "plan"},
243+
{"name": "@data_type", "value": DataType.plan},
244244
{"name": "@user_id", "value": self.user_id},
245245
]
246246
results = await self.query_items(query, parameters, Plan)
@@ -255,7 +255,7 @@ async def get_all_plans(self) -> List[Plan]:
255255
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
256256
parameters = [
257257
{"name": "@user_id", "value": self.user_id},
258-
{"name": "@data_type", "value": "plan"},
258+
{"name": "@data_type", "value": DataType.plan},
259259
]
260260
return await self.query_items(query, parameters, Plan)
261261

@@ -265,7 +265,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
265265
parameters = [
266266
{"name": "@user_id", "value": self.user_id},
267267
{"name": "@team_id", "value": team_id},
268-
{"name": "@data_type", "value": "plan"},
268+
{"name": "@data_type", "value": DataType.plan},
269269
]
270270
return await self.query_items(query, parameters, Plan)
271271

@@ -283,7 +283,7 @@ async def get_steps_by_plan(self, plan_id: str) -> List[Step]:
283283
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c.timestamp"
284284
parameters = [
285285
{"name": "@plan_id", "value": plan_id},
286-
{"name": "@data_type", "value": "step"},
286+
{"name": "@data_type", "value": DataType.step},
287287
]
288288
return await self.query_items(query, parameters, Step)
289289

@@ -293,7 +293,7 @@ async def get_step(self, step_id: str, session_id: str) -> Optional[Step]:
293293
parameters = [
294294
{"name": "@step_id", "value": step_id},
295295
{"name": "@session_id", "value": session_id},
296-
{"name": "@data_type", "value": "step"},
296+
{"name": "@data_type", "value": DataType.step},
297297
]
298298
results = await self.query_items(query, parameters, Step)
299299
return results[0] if results else None
@@ -312,7 +312,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
312312
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type"
313313
parameters = [
314314
{"name": "@team_id", "value": team_id},
315-
{"name": "@data_type", "value": "team_config"},
315+
{"name": "@data_type", "value": DataType.team_config},
316316
]
317317
teams = await self.query_items(query, parameters, TeamConfiguration)
318318
return teams[0] if teams else None
@@ -329,7 +329,7 @@ async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
329329
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
330330
parameters = [
331331
{"name": "@id", "value": id},
332-
{"name": "@data_type", "value": "team_config"},
332+
{"name": "@data_type", "value": DataType.team_config},
333333
]
334334
teams = await self.query_items(query, parameters, TeamConfiguration)
335335
return teams[0] if teams else None
@@ -346,7 +346,7 @@ async def get_all_teams_by_user(self, user_id: str) -> List[TeamConfiguration]:
346346
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type ORDER BY c.created DESC"
347347
parameters = [
348348
{"name": "@user_id", "value": user_id},
349-
{"name": "@data_type", "value": "team_config"},
349+
{"name": "@data_type", "value": DataType.team_config},
350350
]
351351
teams = await self.query_items(query, parameters, TeamConfiguration)
352352
return teams
@@ -452,7 +452,7 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
452452

453453
query = "SELECT * FROM c WHERE c.data_type=@data_type AND c.user_id=@user_id"
454454
parameters = [
455-
{"name": "@data_type", "value": "user_current_team"},
455+
{"name": "@data_type", "value": DataType.user_current_team},
456456
{"name": "@user_id", "value": user_id},
457457
]
458458

src/backend/common/models/messages_kernel.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class DataType(str, Enum):
1212
session = "session"
1313
plan = "plan"
1414
step = "step"
15-
message = "agent_message"
16-
team = "team_config"
15+
agent_message = "agent_message"
16+
team_config = "team_config"
1717
user_current_team = "user_current_team"
1818
m_plan = "m_plan"
1919
m_plan_step = "m_plan_step"
@@ -84,7 +84,7 @@ class BaseDataModel(KernelBaseModel):
8484
class AgentMessage(BaseDataModel):
8585
"""Base class for messages sent between agents."""
8686

87-
data_type: Literal["agent_message"] = Field("agent_message", Literal=True)
87+
data_type: Literal[DataType.agent_message] = Field(DataType.agent_message, Literal=True)
8888
session_id: str
8989
user_id: str
9090
plan_id: str
@@ -96,7 +96,7 @@ class AgentMessage(BaseDataModel):
9696
class Session(BaseDataModel):
9797
"""Represents a user session."""
9898

99-
data_type: Literal["session"] = Field("session", Literal=True)
99+
data_type: Literal[DataType.session] = Field(DataType.session, Literal=True)
100100
user_id: str
101101
current_status: str
102102
message_to_user: Optional[str] = None
@@ -105,15 +105,15 @@ class Session(BaseDataModel):
105105
class UserCurrentTeam(BaseDataModel):
106106
"""Represents the current team of a user."""
107107

108-
data_type: Literal["user_current_team"] = Field("user_current_team", Literal=True)
108+
data_type: Literal[DataType.user_current_team] = Field(DataType.user_current_team, Literal=True)
109109
user_id: str
110110
team_id: str
111111

112112

113113
class Plan(BaseDataModel):
114114
"""Represents a plan containing multiple steps."""
115115

116-
data_type: Literal["plan"] = Field("plan", Literal=True)
116+
data_type: Literal[DataType.plan] = Field(DataType.plan, Literal=True)
117117
plan_id: str
118118
session_id: str
119119
user_id: str
@@ -129,7 +129,7 @@ class Plan(BaseDataModel):
129129
class Step(BaseDataModel):
130130
"""Represents an individual step (task) within a plan."""
131131

132-
data_type: Literal["step"] = Field("step", Literal=True)
132+
data_type: Literal[DataType.step] = Field(DataType.step, Literal=True)
133133
plan_id: str
134134
session_id: str # Partition key
135135
user_id: str
@@ -181,7 +181,7 @@ class TeamConfiguration(BaseDataModel):
181181
"""Represents a team configuration stored in the database."""
182182

183183
team_id: str
184-
data_type: Literal["team_config"] = Field("team_config", Literal=True)
184+
data_type: Literal[DataType.team_config] = Field(DataType.team_config, Literal=True)
185185
session_id: str # Partition key
186186
name: str
187187
status: str

src/backend/v3/api/router.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from common.utils.utils_date import format_dates_in_messages
99
from common.config.app_config import config
10+
from v3.common.services.plan_service import PlanService
1011
import v3.models.messages as messages
1112
from auth.auth_utils import get_authenticated_user_details
1213
from common.database.database_factory import DatabaseFactory
@@ -380,16 +381,12 @@ async def plan_approval(
380381
# orchestration_config.plans[human_feedback.m_plan_id],
381382
# )
382383
try:
383-
plan = orchestration_config.plans[human_feedback.m_plan_id]
384-
if hasattr(plan, "plan_id"):
385-
print(
386-
"Updated orchestration config:",
387-
orchestration_config.plans[human_feedback.m_plan_id],
388-
)
389-
plan.plan_id = human_feedback.plan_id
390-
orchestration_config.plans[human_feedback.m_plan_id] = plan
384+
result = await PlanService.handle_plan_approval(human_feedback, user_id)
385+
print("Plan approval processed:", result)
386+
except ValueError as ve:
387+
print(f"ValueError processing plan approval: {ve}")
391388
except Exception as e:
392-
print(f"Error processing plan approval: {e}")
389+
print(f"Error processing plan approval: {e}")
393390
track_event_if_configured(
394391
"PlanApprovalReceived",
395392
{
@@ -400,6 +397,7 @@ async def plan_approval(
400397
"feedback": human_feedback.feedback,
401398
},
402399
)
400+
403401
return {"status": "approval recorded"}
404402
else:
405403
logging.warning(
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
from typing import Dict, Any, Optional
3+
from common.database.database_factory import DatabaseFactory
4+
from common.database.database_base import DatabaseBase
5+
import v3.models.messages as messages
6+
from v3.config.settings import orchestration_config
7+
from common.utils.event_utils import track_event_if_configured
8+
9+
logger = logging.getLogger(__name__)
10+
11+
class PlanService:
12+
13+
14+
@staticmethod
15+
async def handle_plan_approval(human_feedback: messages.PlanApprovalResponse, user_id: str) -> bool:
16+
"""
17+
Process a PlanApprovalResponse coming from the client.
18+
19+
Args:
20+
feedback: messages.PlanApprovalResponse (contains m_plan_id, plan_id, approved, feedback)
21+
user_id: authenticated user id
22+
23+
Returns:
24+
dict with status and metadata
25+
26+
Raises:
27+
ValueError on invalid state
28+
"""
29+
if orchestration_config is None:
30+
return False
31+
try:
32+
mplan = orchestration_config.plans[human_feedback.m_plan_id]
33+
if hasattr(mplan, "plan_id"):
34+
print(
35+
"Updated orchestration config:",
36+
orchestration_config.plans[human_feedback.m_plan_id],
37+
)
38+
mplan.plan_id = human_feedback.plan_id
39+
orchestration_config.plans[human_feedback.m_plan_id] = mplan
40+
memory_store = await DatabaseFactory.get_database(user_id=user_id)
41+
plan = await memory_store.get_plan(human_feedback.plan_id)
42+
if plan:
43+
print("Retrieved plan from memory store:", plan)
44+
45+
46+
else:
47+
print("Plan not found in memory store.")
48+
return False
49+
50+
except Exception as e:
51+
print(f"Error processing plan approval: {e}")
52+
return False
53+
return True

0 commit comments

Comments
 (0)