Skip to content

Commit 377365f

Browse files
committed
Refactor agent message handling and plan retrieval
Updated database interfaces and implementations to use AgentMessageData instead of AgentMessageResponse for agent message operations. Added ordering to CosmosDB queries for plans and messages. Modified API route for plan retrieval to return plan, team, messages, and m_plan as structured objects instead of a list, improving response consistency and context.
1 parent f141f5d commit 377365f

File tree

3 files changed

+39
-33
lines changed

3 files changed

+39
-33
lines changed

src/backend/common/database/cosmosdb.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
237237

238238
async def get_all_plans_by_team_id_status(self, team_id: str, status: str) -> List[Plan]:
239239
"""Retrieve all plans for a specific team."""
240-
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status"
240+
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status ORDER BY c._ts DESC"
241241
parameters = [
242242
{"name": "@user_id", "value": self.user_id},
243243
{"name": "@team_id", "value": team_id},
@@ -474,7 +474,7 @@ async def update_mplan(self, mplan: messages.MPlan) -> None:
474474

475475
async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]:
476476
"""Retrieve a mplan configuration by mplan_id."""
477-
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type"
477+
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c._ts ASC"
478478
parameters = [
479479
{"name": "@plan_id", "value": plan_id},
480480
{"name": "@data_type", "value": DataType.m_plan},
@@ -483,20 +483,20 @@ async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]:
483483
return results[0] if results else None
484484

485485

486-
async def add_agent_message(self, message: messages.AgentMessageResponse) -> None:
486+
async def add_agent_message(self, message: AgentMessageData) -> None:
487487
"""Add an agent message to the database."""
488488
await self.add_item(message)
489489

490-
async def update_agent_message(self, message: messages.AgentMessageResponse) -> None:
490+
async def update_agent_message(self, message: AgentMessageData) -> None:
491491
"""Update an agent message in the database."""
492492
await self.update_item(message)
493493

494-
async def get_agent_messages(self, plan_id: str) -> List[messages.AgentMessageResponse]:
494+
async def get_agent_messages(self, plan_id: str) -> List[AgentMessageData]:
495495
"""Retrieve an agent message by message_id."""
496-
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type"
496+
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c._ts ASC"
497497
parameters = [
498498
{"name": "@plan_id", "value": plan_id},
499499
{"name": "@data_type", "value": DataType.m_plan_message},
500500
]
501501

502-
return await self.query_items(query, parameters, messages.AgentMessageResponse)
502+
return await self.query_items(query, parameters, AgentMessageData)

src/backend/common/database/database_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,11 @@ async def add_agent_message(self, message: AgentMessageData) -> None:
222222
pass
223223

224224
@abstractmethod
225-
async def update_agent_message(self, message: AgentMessageResponse) -> None:
225+
async def update_agent_message(self, message: AgentMessageData) -> None:
226226
"""Update an agent message in the database."""
227227
pass
228228

229229
@abstractmethod
230-
async def get_agent_messages(self, plan_id: str) -> Optional[AgentMessageResponse]:
230+
async def get_agent_messages(self, plan_id: str) -> Optional[AgentMessageData]:
231231
"""Retrieve an agent message by message_id."""
232232
pass

src/backend/v3/api/router.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,8 @@ async def user_clarification(
559559

560560
try:
561561
result = await PlanService.handle_human_clarification(
562-
human_feedback, user_id)
562+
human_feedback, user_id
563+
)
563564
print("Human clarification processed:", result)
564565
except ValueError as ve:
565566
print(f"ValueError processing human clarification: {ve}")
@@ -573,7 +574,9 @@ async def user_clarification(
573574
"user_id": user_id,
574575
},
575576
)
576-
return {"status": "clarification recorded",}
577+
return {
578+
"status": "clarification recorded",
579+
}
577580
else:
578581
logging.warning(
579582
f"No orchestration or plan found for request_id: {human_feedback.request_id}"
@@ -582,6 +585,7 @@ async def user_clarification(
582585
status_code=404, detail="No active plan found for clarification"
583586
)
584587

588+
585589
@app_v3.post("/agent_message")
586590
async def agent_message_user(
587591
agent_message: messages.AgentMessageResponse, request: Request
@@ -642,14 +646,14 @@ async def agent_message_user(
642646
# Set the approval in the orchestration config
643647

644648
try:
645-
649+
646650
result = await PlanService.handle_agent_messages(agent_message, user_id)
647651
print("Agent message processed:", result)
648652
except ValueError as ve:
649653
print(f"ValueError processing agent message: {ve}")
650654
except Exception as e:
651655
print(f"Error processing agent message: {e}")
652-
656+
653657
track_event_if_configured(
654658
"AgentMessageReceived",
655659
{
@@ -658,7 +662,9 @@ async def agent_message_user(
658662
"user_id": user_id,
659663
},
660664
)
661-
return {"status": "message recorded",}
665+
return {
666+
"status": "message recorded",
667+
}
662668

663669

664670
@app_v3.post("/upload_team_config")
@@ -726,19 +732,19 @@ async def upload_team_config(
726732
track_event_if_configured(
727733
"Team configuration RAI validation failed",
728734
{
729-
"status": "failed",
730-
"user_id": user_id,
731-
"filename": file.filename,
732-
"reason": rai_error,
733-
},
734-
)
735-
735+
"status": "failed",
736+
"user_id": user_id,
737+
"filename": file.filename,
738+
"reason": rai_error,
739+
},
740+
)
741+
736742
raise HTTPException(status_code=400, detail=rai_error)
737743

738744
track_event_if_configured(
739-
"Team configuration RAI validation passed",
740-
{"status": "passed", "user_id": user_id, "filename": file.filename},
741-
)
745+
"Team configuration RAI validation passed",
746+
{"status": "passed", "user_id": user_id, "filename": file.filename},
747+
)
742748
# Initialize memory store and service
743749
memory_store = await DatabaseFactory.get_database(user_id=user_id)
744750
team_service = TeamService(memory_store)
@@ -1340,16 +1346,16 @@ async def get_plan_by_id(request: Request, plan_id: str):
13401346
raise HTTPException(status_code=404, detail="Plan not found")
13411347

13421348
# Use get_steps_by_plan to match the original implementation
1343-
steps = await memory_store.get_steps_by_plan(plan_id=plan.id)
1344-
messages = []
13451349

1346-
plan_with_steps = PlanWithSteps(**plan.model_dump(), steps=steps)
1347-
plan_with_steps.update_step_counts()
1348-
1349-
# Format dates in messages according to locale
1350-
formatted_messages = []
1351-
1352-
return [plan_with_steps, formatted_messages]
1350+
team = await memory_store.get_team_by_id(team_id=plan.team_id)
1351+
messages = await memory_store.get_agent_messages(plan_id=plan.plan_id)
1352+
m_plan = await memory_store.get_m_plan_by_plan_id(plan_id=plan.plan_id)
1353+
return {
1354+
"plan": plan.model_dump(),
1355+
"team": team.model_dump() if team else None,
1356+
"messages": [msg.model_dump() for msg in messages],
1357+
"m_plan": m_plan.model_dump() if m_plan else None,
1358+
}
13531359
else:
13541360
track_event_if_configured(
13551361
"GetPlanId", {"status_code": 400, "detail": "no plan id"}

0 commit comments

Comments
 (0)