Skip to content

Commit 2efbece

Browse files
authored
Merge pull request #455 from microsoft/macae-v3-fr-dev-92
Macae v3 fr dev 92
2 parents 59867a4 + 123a421 commit 2efbece

30 files changed

+1356
-1282
lines changed

src/backend/common/database/cosmosdb.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ class CosmosDBClient(DatabaseBase):
3838
"""CosmosDB implementation of the database interface."""
3939

4040
MODEL_CLASS_MAPPING = {
41-
"session": Session,
42-
"plan": Plan,
43-
"step": Step,
44-
"agent_message": AgentMessage,
45-
"team_config": TeamConfiguration,
46-
"user_current_team": UserCurrentTeam,
41+
DataType.session: Session,
42+
DataType.plan: Plan,
43+
DataType.step: Step,
44+
DataType.agent_message: AgentMessage,
45+
DataType.team_config: TeamConfiguration,
46+
DataType.user_current_team: UserCurrentTeam,
4747
}
4848

4949
def __init__(
@@ -200,7 +200,7 @@ async def get_session(self, session_id: str) -> Optional[Session]:
200200
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
201201
parameters = [
202202
{"name": "@id", "value": session_id},
203-
{"name": "@data_type", "value": "session"},
203+
{"name": "@data_type", "value": DataType.session},
204204
]
205205
results = await self.query_items(query, parameters, Session)
206206
return results[0] if results else None
@@ -210,7 +210,7 @@ async def get_all_sessions(self) -> List[Session]:
210210
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
211211
parameters = [
212212
{"name": "@user_id", "value": self.user_id},
213-
{"name": "@data_type", "value": "session"},
213+
{"name": "@data_type", "value": DataType.session},
214214
]
215215
return await self.query_items(query, parameters, Session)
216216

@@ -230,7 +230,7 @@ async def get_plan_by_session(self, session_id: str) -> Optional[Plan]:
230230
)
231231
parameters = [
232232
{"name": "@session_id", "value": session_id},
233-
{"name": "@data_type", "value": "plan"},
233+
{"name": "@data_type", "value": DataType.plan},
234234
]
235235
results = await self.query_items(query, parameters, Plan)
236236
return results[0] if results else None
@@ -240,7 +240,7 @@ async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]:
240240
query = "SELECT * FROM c WHERE c.id=@plan_id AND c.data_type=@data_type"
241241
parameters = [
242242
{"name": "@plan_id", "value": plan_id},
243-
{"name": "@data_type", "value": "plan"},
243+
{"name": "@data_type", "value": DataType.plan},
244244
{"name": "@user_id", "value": self.user_id},
245245
]
246246
results = await self.query_items(query, parameters, Plan)
@@ -255,7 +255,7 @@ async def get_all_plans(self) -> List[Plan]:
255255
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type"
256256
parameters = [
257257
{"name": "@user_id", "value": self.user_id},
258-
{"name": "@data_type", "value": "plan"},
258+
{"name": "@data_type", "value": DataType.plan},
259259
]
260260
return await self.query_items(query, parameters, Plan)
261261

@@ -265,7 +265,7 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
265265
parameters = [
266266
{"name": "@user_id", "value": self.user_id},
267267
{"name": "@team_id", "value": team_id},
268-
{"name": "@data_type", "value": "plan"},
268+
{"name": "@data_type", "value": DataType.plan},
269269
]
270270
return await self.query_items(query, parameters, Plan)
271271

@@ -283,7 +283,7 @@ async def get_steps_by_plan(self, plan_id: str) -> List[Step]:
283283
query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type ORDER BY c.timestamp"
284284
parameters = [
285285
{"name": "@plan_id", "value": plan_id},
286-
{"name": "@data_type", "value": "step"},
286+
{"name": "@data_type", "value": DataType.step},
287287
]
288288
return await self.query_items(query, parameters, Step)
289289

@@ -293,7 +293,7 @@ async def get_step(self, step_id: str, session_id: str) -> Optional[Step]:
293293
parameters = [
294294
{"name": "@step_id", "value": step_id},
295295
{"name": "@session_id", "value": session_id},
296-
{"name": "@data_type", "value": "step"},
296+
{"name": "@data_type", "value": DataType.step},
297297
]
298298
results = await self.query_items(query, parameters, Step)
299299
return results[0] if results else None
@@ -312,7 +312,7 @@ async def get_team(self, team_id: str) -> Optional[TeamConfiguration]:
312312
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type"
313313
parameters = [
314314
{"name": "@team_id", "value": team_id},
315-
{"name": "@data_type", "value": "team_config"},
315+
{"name": "@data_type", "value": DataType.team_config},
316316
]
317317
teams = await self.query_items(query, parameters, TeamConfiguration)
318318
return teams[0] if teams else None
@@ -329,7 +329,7 @@ async def get_team_by_id(self, id: str) -> Optional[TeamConfiguration]:
329329
query = "SELECT * FROM c WHERE c.id=@id AND c.data_type=@data_type"
330330
parameters = [
331331
{"name": "@id", "value": id},
332-
{"name": "@data_type", "value": "team_config"},
332+
{"name": "@data_type", "value": DataType.team_config},
333333
]
334334
teams = await self.query_items(query, parameters, TeamConfiguration)
335335
return teams[0] if teams else None
@@ -346,7 +346,7 @@ async def get_all_teams_by_user(self, user_id: str) -> List[TeamConfiguration]:
346346
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type ORDER BY c.created DESC"
347347
parameters = [
348348
{"name": "@user_id", "value": user_id},
349-
{"name": "@data_type", "value": "team_config"},
349+
{"name": "@data_type", "value": DataType.team_config},
350350
]
351351
teams = await self.query_items(query, parameters, TeamConfiguration)
352352
return teams
@@ -452,7 +452,7 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
452452

453453
query = "SELECT * FROM c WHERE c.data_type=@data_type AND c.user_id=@user_id"
454454
parameters = [
455-
{"name": "@data_type", "value": "user_current_team"},
455+
{"name": "@data_type", "value": DataType.user_current_team},
456456
{"name": "@user_id", "value": user_id},
457457
]
458458

src/backend/common/models/messages_kernel.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class DataType(str, Enum):
1212
session = "session"
1313
plan = "plan"
1414
step = "step"
15-
message = "agent_message"
16-
team = "team_config"
15+
agent_message = "agent_message"
16+
team_config = "team_config"
1717
user_current_team = "user_current_team"
1818
m_plan = "m_plan"
1919
m_plan_step = "m_plan_step"
@@ -84,7 +84,7 @@ class BaseDataModel(KernelBaseModel):
8484
class AgentMessage(BaseDataModel):
8585
"""Base class for messages sent between agents."""
8686

87-
data_type: Literal["agent_message"] = Field("agent_message", Literal=True)
87+
data_type: Literal[DataType.agent_message] = Field(DataType.agent_message, Literal=True)
8888
session_id: str
8989
user_id: str
9090
plan_id: str
@@ -96,7 +96,7 @@ class AgentMessage(BaseDataModel):
9696
class Session(BaseDataModel):
9797
"""Represents a user session."""
9898

99-
data_type: Literal["session"] = Field("session", Literal=True)
99+
data_type: Literal[DataType.session] = Field(DataType.session, Literal=True)
100100
user_id: str
101101
current_status: str
102102
message_to_user: Optional[str] = None
@@ -105,15 +105,15 @@ class Session(BaseDataModel):
105105
class UserCurrentTeam(BaseDataModel):
106106
"""Represents the current team of a user."""
107107

108-
data_type: Literal["user_current_team"] = Field("user_current_team", Literal=True)
108+
data_type: Literal[DataType.user_current_team] = Field(DataType.user_current_team, Literal=True)
109109
user_id: str
110110
team_id: str
111111

112112

113113
class Plan(BaseDataModel):
114114
"""Represents a plan containing multiple steps."""
115115

116-
data_type: Literal["plan"] = Field("plan", Literal=True)
116+
data_type: Literal[DataType.plan] = Field(DataType.plan, Literal=True)
117117
plan_id: str
118118
session_id: str
119119
user_id: str
@@ -129,7 +129,7 @@ class Plan(BaseDataModel):
129129
class Step(BaseDataModel):
130130
"""Represents an individual step (task) within a plan."""
131131

132-
data_type: Literal["step"] = Field("step", Literal=True)
132+
data_type: Literal[DataType.step] = Field(DataType.step, Literal=True)
133133
plan_id: str
134134
session_id: str # Partition key
135135
user_id: str
@@ -181,7 +181,7 @@ class TeamConfiguration(BaseDataModel):
181181
"""Represents a team configuration stored in the database."""
182182

183183
team_id: str
184-
data_type: Literal["team_config"] = Field("team_config", Literal=True)
184+
data_type: Literal[DataType.team_config] = Field(DataType.team_config, Literal=True)
185185
session_id: str # Partition key
186186
name: str
187187
status: str
@@ -232,9 +232,11 @@ def update_step_counts(self):
232232
self.completed = status_counts[StepStatus.completed]
233233
self.failed = status_counts[StepStatus.failed]
234234

235-
# Mark the plan as complete if the sum of completed and failed steps equals the total number of steps
236-
if self.completed + self.failed == self.total_steps:
235+
236+
if self.total_steps > 0 and (self.completed + self.failed) == self.total_steps:
237237
self.overall_status = PlanStatus.completed
238+
# Mark the plan as complete if the sum of completed and failed steps equals the total number of steps
239+
238240

239241

240242
# Message classes for communication between agents

src/backend/v3/api/router.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from common.utils.utils_date import format_dates_in_messages
99
from common.config.app_config import config
10+
from v3.common.services.plan_service import PlanService
1011
import v3.models.messages as messages
1112
from auth.auth_utils import get_authenticated_user_details
1213
from common.database.database_factory import DatabaseFactory
@@ -380,16 +381,12 @@ async def plan_approval(
380381
# orchestration_config.plans[human_feedback.m_plan_id],
381382
# )
382383
try:
383-
plan = orchestration_config.plans[human_feedback.m_plan_id]
384-
if hasattr(plan, "plan_id"):
385-
print(
386-
"Updated orchestration config:",
387-
orchestration_config.plans[human_feedback.m_plan_id],
388-
)
389-
plan.plan_id = human_feedback.plan_id
390-
orchestration_config.plans[human_feedback.m_plan_id] = plan
384+
result = await PlanService.handle_plan_approval(human_feedback, user_id)
385+
print("Plan approval processed:", result)
386+
except ValueError as ve:
387+
print(f"ValueError processing plan approval: {ve}")
391388
except Exception as e:
392-
print(f"Error processing plan approval: {e}")
389+
print(f"Error processing plan approval: {e}")
393390
track_event_if_configured(
394391
"PlanApprovalReceived",
395392
{
@@ -400,6 +397,7 @@ async def plan_approval(
400397
"feedback": human_feedback.feedback,
401398
},
402399
)
400+
403401
return {"status": "approval recorded"}
404402
else:
405403
logging.warning(
@@ -1030,9 +1028,18 @@ async def get_plans(request: Request):
10301028
if not current_team:
10311029
return []
10321030

1033-
all_plans = await memory_store.get_all_plans_by_team_id(team_id=current_team.id)
1031+
all_plans = await memory_store.get_all_plans_by_team_id(team_id=current_team.team_id)
10341032

1035-
return all_plans
1033+
steps_for_all_plans = []
1034+
# Create list of PlanWithSteps and update step counts
1035+
list_of_plans_with_steps = []
1036+
for plan in all_plans:
1037+
plan_with_steps = PlanWithSteps(**plan.model_dump(), steps=[])
1038+
plan_with_steps.overall_status
1039+
plan_with_steps.update_step_counts()
1040+
list_of_plans_with_steps.append(plan_with_steps)
1041+
1042+
return list_of_plans_with_steps
10361043

10371044

10381045
# Get plans is called in the initial side rendering of the frontend
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
from typing import Dict, Any, Optional
3+
from common.database.database_factory import DatabaseFactory
4+
from common.database.database_base import DatabaseBase
5+
import v3.models.messages as messages
6+
from v3.config.settings import orchestration_config
7+
from common.utils.event_utils import track_event_if_configured
8+
9+
logger = logging.getLogger(__name__)
10+
11+
class PlanService:
12+
13+
14+
@staticmethod
15+
async def handle_plan_approval(human_feedback: messages.PlanApprovalResponse, user_id: str) -> bool:
16+
"""
17+
Process a PlanApprovalResponse coming from the client.
18+
19+
Args:
20+
feedback: messages.PlanApprovalResponse (contains m_plan_id, plan_id, approved, feedback)
21+
user_id: authenticated user id
22+
23+
Returns:
24+
dict with status and metadata
25+
26+
Raises:
27+
ValueError on invalid state
28+
"""
29+
if orchestration_config is None:
30+
return False
31+
try:
32+
mplan = orchestration_config.plans[human_feedback.m_plan_id]
33+
if hasattr(mplan, "plan_id"):
34+
print(
35+
"Updated orchestration config:",
36+
orchestration_config.plans[human_feedback.m_plan_id],
37+
)
38+
mplan.plan_id = human_feedback.plan_id
39+
orchestration_config.plans[human_feedback.m_plan_id] = mplan
40+
memory_store = await DatabaseFactory.get_database(user_id=user_id)
41+
plan = await memory_store.get_plan(human_feedback.plan_id)
42+
if plan:
43+
print("Retrieved plan from memory store:", plan)
44+
45+
46+
else:
47+
print("Plan not found in memory store.")
48+
return False
49+
50+
except Exception as e:
51+
print(f"Error processing plan approval: {e}")
52+
return False
53+
return True
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { TeamConfig } from "@/models";
2+
import { Body1, Caption1 } from "@fluentui/react-components";
3+
4+
export interface TeamSelectedProps {
5+
selectedTeam?: TeamConfig | null;
6+
styles: { [key: string]: string };
7+
}
8+
9+
const TeamSelected: React.FC<TeamSelectedProps> = ({ selectedTeam, styles }) => {
10+
return (
11+
<div className={styles.teamSelectorContent}>
12+
<Caption1 className={styles.currentTeamLabel}>
13+
Current Team
14+
</Caption1>
15+
<Body1 className={styles.currentTeamName}>
16+
{selectedTeam ? selectedTeam.name : 'No team selected'}
17+
</Body1>
18+
</div>
19+
);
20+
}
21+
export default TeamSelected;

0 commit comments

Comments
 (0)