Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/backend/common/database/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ class CosmosDBClient(DatabaseBase):
"""CosmosDB implementation of the database interface."""

MODEL_CLASS_MAPPING = {
"session": Session,
"plan": Plan,
"step": Step,
"agent_message": AgentMessage,
"team_config": TeamConfiguration,
"user_current_team": UserCurrentTeam,
DataType.session: Session,
DataType.plan: Plan,
DataType.step: Step,
DataType.agent_message: AgentMessage,
DataType.team_config: TeamConfiguration,
DataType.user_current_team: UserCurrentTeam,
}

def __init__(
Expand Down Expand Up @@ -200,7 +200,7 @@ async def get_session(self, session_id: str) -> Optional[Session]:
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
parameters = [
{"name": "@id", "value": session_id},
{"name": "@data_type", "value": "session"},
{"name": "@data_type", "value": DataType.session},
]
results = await self.query_items(query, parameters, Session)
return results[0] if results else None
Expand All @@ -210,7 +210,7 @@ async def get_all_sessions(self) -> List[Session]:
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
parameters = [
{"name": "@user_id", "value": self.user_id},
{"name": "@data_type", "value": "session"},
{"name": "@data_type", "value": DataType.session},
]
return await self.query_items(query, parameters, Session)

Expand All @@ -230,7 +230,7 @@ async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
)
parameters = [
{"name": "@session_id", "value": session_id},
{"name": "@data_type", "value": "plan"},
{"name": "@data_type", "value": DataType.plan},
]
results = await self.query_items(query, parameters, Plan)
return results[0] if results else None
Expand All @@ -240,7 +240,7 @@ async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
query = "SELECT * FROM c WHERE c.id=@plan_id AND c.data_type=@data_type"
parameters = [
{"name": "@plan_id", "value": plan_id},
{"name": "@data_type", "value": "plan"},
{"name": "@data_type", "value": DataType.plan},
{"name": "@user_id", "value": self.user_id},
]
results = await self.query_items(query, parameters, Plan)
Expand All @@ -255,7 +255,7 @@ async def get_all_plans(self) -> List[Plan]:
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
parameters = [
{"name": "@user_id", "value": self.user_id},
{"name": "@data_type", "value": "plan"},
{"name": "@data_type", "value": DataType.plan},
]
return await self.query_items(query, parameters, Plan)

Expand All @@ -265,7 +265,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
parameters = [
{"name": "@user_id", "value": self.user_id},
{"name": "@team_id", "value": team_id},
{"name": "@data_type", "value": "plan"},
{"name": "@data_type", "value": DataType.plan},
]
return await self.query_items(query, parameters, Plan)

Expand All @@ -283,7 +283,7 @@ async def get_steps_by_plan(self, plan_id: str) -> List[Step]:
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c.timestamp"
parameters = [
{"name": "@plan_id", "value": plan_id},
{"name": "@data_type", "value": "step"},
{"name": "@data_type", "value": DataType.step},
]
return await self.query_items(query, parameters, Step)

Expand All @@ -293,7 +293,7 @@ async def get_step(self, step_id: str, session_id: str) -> Optional[Step]:
parameters = [
{"name": "@step_id", "value": step_id},
{"name": "@session_id", "value": session_id},
{"name": "@data_type", "value": "step"},
{"name": "@data_type", "value": DataType.step},
]
results = await self.query_items(query, parameters, Step)
return results[0] if results else None
Expand All @@ -312,7 +312,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type"
parameters = [
{"name": "@team_id", "value": team_id},
{"name": "@data_type", "value": "team_config"},
{"name": "@data_type", "value": DataType.team_config},
]
teams = await self.query_items(query, parameters, TeamConfiguration)
return teams[0] if teams else None
Expand All @@ -329,7 +329,7 @@ async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
parameters = [
{"name": "@id", "value": id},
{"name": "@data_type", "value": "team_config"},
{"name": "@data_type", "value": DataType.team_config},
]
teams = await self.query_items(query, parameters, TeamConfiguration)
return teams[0] if teams else None
Expand All @@ -346,7 +346,7 @@ async def get_all_teams_by_user(self, user_id: str) -> List[TeamConfiguration]:
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type ORDER BY c.created DESC"
parameters = [
{"name": "@user_id", "value": user_id},
{"name": "@data_type", "value": "team_config"},
{"name": "@data_type", "value": DataType.team_config},
]
teams = await self.query_items(query, parameters, TeamConfiguration)
return teams
Expand Down Expand Up @@ -452,7 +452,7 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:

query = "SELECT * FROM c WHERE c.data_type=@data_type AND c.user_id=@user_id"
parameters = [
{"name": "@data_type", "value": "user_current_team"},
{"name": "@data_type", "value": DataType.user_current_team},
{"name": "@user_id", "value": user_id},
]

Expand Down
22 changes: 12 additions & 10 deletions src/backend/common/models/messages_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class DataType(str, Enum):
session = "session"
plan = "plan"
step = "step"
message = "agent_message"
team = "team_config"
agent_message = "agent_message"
team_config = "team_config"
user_current_team = "user_current_team"
m_plan = "m_plan"
m_plan_step = "m_plan_step"
Expand Down Expand Up @@ -84,7 +84,7 @@ class BaseDataModel(KernelBaseModel):
class AgentMessage(BaseDataModel):
"""Base class for messages sent between agents."""

data_type: Literal["agent_message"] = Field("agent_message", Literal=True)
data_type: Literal[DataType.agent_message] = Field(DataType.agent_message, Literal=True)
session_id: str
user_id: str
plan_id: str
Expand All @@ -96,7 +96,7 @@ class AgentMessage(BaseDataModel):
class Session(BaseDataModel):
"""Represents a user session."""

data_type: Literal["session"] = Field("session", Literal=True)
data_type: Literal[DataType.session] = Field(DataType.session, Literal=True)
user_id: str
current_status: str
message_to_user: Optional[str] = None
Expand All @@ -105,15 +105,15 @@ class Session(BaseDataModel):
class UserCurrentTeam(BaseDataModel):
"""Represents the current team of a user."""

data_type: Literal["user_current_team"] = Field("user_current_team", Literal=True)
data_type: Literal[DataType.user_current_team] = Field(DataType.user_current_team, Literal=True)
user_id: str
team_id: str


class Plan(BaseDataModel):
"""Represents a plan containing multiple steps."""

data_type: Literal["plan"] = Field("plan", Literal=True)
data_type: Literal[DataType.plan] = Field(DataType.plan, Literal=True)
plan_id: str
session_id: str
user_id: str
Expand All @@ -129,7 +129,7 @@ class Plan(BaseDataModel):
class Step(BaseDataModel):
"""Represents an individual step (task) within a plan."""

data_type: Literal["step"] = Field("step", Literal=True)
data_type: Literal[DataType.step] = Field(DataType.step, Literal=True)
plan_id: str
session_id: str # Partition key
user_id: str
Expand Down Expand Up @@ -181,7 +181,7 @@ class TeamConfiguration(BaseDataModel):
"""Represents a team configuration stored in the database."""

team_id: str
data_type: Literal["team_config"] = Field("team_config", Literal=True)
data_type: Literal[DataType.team_config] = Field(DataType.team_config, Literal=True)
session_id: str # Partition key
name: str
status: str
Expand Down Expand Up @@ -232,9 +232,11 @@ def update_step_counts(self):
self.completed = status_counts[StepStatus.completed]
self.failed = status_counts[StepStatus.failed]

# Mark the plan as complete if the sum of completed and failed steps equals the total number of steps
if self.completed + self.failed == self.total_steps:

if self.total_steps > 0 and (self.completed + self.failed) == self.total_steps:
self.overall_status = PlanStatus.completed
# Mark the plan as complete if the sum of completed and failed steps equals the total number of steps



# Message classes for communication between agents
Expand Down
29 changes: 18 additions & 11 deletions src/backend/v3/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from common.utils.utils_date import format_dates_in_messages
from common.config.app_config import config
from v3.common.services.plan_service import PlanService
import v3.models.messages as messages
from auth.auth_utils import get_authenticated_user_details
from common.database.database_factory import DatabaseFactory
Expand Down Expand Up @@ -380,16 +381,12 @@ async def plan_approval(
# orchestration_config.plans[human_feedback.m_plan_id],
# )
try:
plan = orchestration_config.plans[human_feedback.m_plan_id]
if hasattr(plan, "plan_id"):
print(
"Updated orchestration config:",
orchestration_config.plans[human_feedback.m_plan_id],
)
plan.plan_id = human_feedback.plan_id
orchestration_config.plans[human_feedback.m_plan_id] = plan
result = await PlanService.handle_plan_approval(human_feedback, user_id)
print("Plan approval processed:", result)
except ValueError as ve:
print(f"ValueError processing plan approval: {ve}")
except Exception as e:
print(f"Error processing plan approval: {e}")
print(f"Error processing plan approval: {e}")
track_event_if_configured(
"PlanApprovalReceived",
{
Expand All @@ -400,6 +397,7 @@ async def plan_approval(
"feedback": human_feedback.feedback,
},
)

return {"status": "approval recorded"}
else:
logging.warning(
Expand Down Expand Up @@ -1030,9 +1028,18 @@ async def get_plans(request: Request):
if not current_team:
return []

all_plans = await memory_store.get_all_plans_by_team_id(team_id=current_team.id)
all_plans = await memory_store.get_all_plans_by_team_id(team_id=current_team.team_id)

return all_plans
steps_for_all_plans = []
# Create list of PlanWithSteps and update step counts
list_of_plans_with_steps = []
for plan in all_plans:
plan_with_steps = PlanWithSteps(**plan.model_dump(), steps=[])
plan_with_steps.overall_status
plan_with_steps.update_step_counts()
list_of_plans_with_steps.append(plan_with_steps)

return list_of_plans_with_steps


# Get plans is called in the initial side rendering of the frontend
Expand Down
53 changes: 53 additions & 0 deletions src/backend/v3/common/services/plan_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging
from typing import Dict, Any, Optional
from common.database.database_factory import DatabaseFactory
from common.database.database_base import DatabaseBase
import v3.models.messages as messages
from v3.config.settings import orchestration_config
from common.utils.event_utils import track_event_if_configured

logger = logging.getLogger(__name__)

class PlanService:


@staticmethod
async def handle_plan_approval(human_feedback: messages.PlanApprovalResponse, user_id: str) -> bool:
"""
Process a PlanApprovalResponse coming from the client.

Args:
feedback: messages.PlanApprovalResponse (contains m_plan_id, plan_id, approved, feedback)
user_id: authenticated user id

Returns:
dict with status and metadata

Raises:
ValueError on invalid state
"""
if orchestration_config is None:
return False
try:
mplan = orchestration_config.plans[human_feedback.m_plan_id]
if hasattr(mplan, "plan_id"):
print(
"Updated orchestration config:",
orchestration_config.plans[human_feedback.m_plan_id],
)
mplan.plan_id = human_feedback.plan_id
orchestration_config.plans[human_feedback.m_plan_id] = mplan
memory_store = await DatabaseFactory.get_database(user_id=user_id)
plan = await memory_store.get_plan(human_feedback.plan_id)
if plan:
print("Retrieved plan from memory store:", plan)


else:
print("Plan not found in memory store.")
return False

except Exception as e:
print(f"Error processing plan approval: {e}")
return False
return True
21 changes: 21 additions & 0 deletions src/frontend/src/components/common/TeamSelected.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { TeamConfig } from "@/models";
import { Body1, Caption1 } from "@fluentui/react-components";

export interface TeamSelectedProps {
selectedTeam?: TeamConfig | null;
styles: { [key: string]: string };
}

const TeamSelected: React.FC<TeamSelectedProps> = ({ selectedTeam, styles }) => {
return (
<div className={styles.teamSelectorContent}>
<Caption1 className={styles.currentTeamLabel}>
Current Team
</Caption1>
<Body1 className={styles.currentTeamName}>
{selectedTeam ? selectedTeam.name : 'No team selected'}
</Body1>
</div>
);
}
export default TeamSelected;
Loading
Loading