Skip to content

Commit adfb483

Browse files
authored
Merge pull request microsoft#442 from microsoft/macae-v3-fr-dev-92
feat: Macae v3 fr dev 92
2 parents dbbd5e1 + 09770e4 commit adfb483

File tree

8 files changed

+202
-267
lines changed

8 files changed

+202
-267
lines changed

src/backend/app_kernel.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,7 @@ async def approve_step_endpoint(
598598
@app.get("/api/plans")
599599
async def get_plans(
600600
request: Request,
601-
session_id: Optional[str] = Query(None),
602601
plan_id: Optional[str] = Query(None),
603-
team_id: Optional[str] = Query(None),
604602
):
605603
"""
606604
Retrieve plans for the current user.
@@ -673,20 +671,7 @@ async def get_plans(
673671

674672
# # Initialize memory context
675673
memory_store = await DatabaseFactory.get_database(user_id=user_id)
676-
if session_id:
677-
plan = await memory_store.get_plan_by_session(session_id=session_id)
678-
if not plan:
679-
track_event_if_configured(
680-
"GetPlanBySessionNotFound",
681-
{"status_code": 400, "detail": "Plan not found"},
682-
)
683-
raise HTTPException(status_code=404, detail="Plan not found")
684674

685-
# Use get_steps_by_plan to match the original implementation
686-
steps = await memory_store.get_steps_by_plan(plan_id=plan.id)
687-
plan_with_steps = PlanWithSteps(**plan.model_dump(), steps=steps)
688-
plan_with_steps.update_step_counts()
689-
return [plan_with_steps]
690675
if plan_id:
691676
plan = await memory_store.get_plan_by_plan_id(plan_id=plan_id)
692677
if not plan:
@@ -712,7 +697,11 @@ async def get_plans(
712697

713698
return [plan_with_steps, formatted_messages]
714699

715-
all_plans = await memory_store.get_all_plans()
700+
current_team = await memory_store.get_current_team(user_id=user_id)
701+
if not current_team:
702+
return []
703+
704+
all_plans = await memory_store.get_all_plans_by_team_id(team_id=current_team.id)
716705
# Fetch steps for all plans concurrently
717706
steps_for_all_plans = await asyncio.gather(
718707
*[memory_store.get_steps_by_plan(plan_id=plan.id) for plan in all_plans]

src/backend/common/database/cosmosdb.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,16 @@ async def get_all_plans(self) -> List[Plan]:
259259
]
260260
return await self.query_items(query, parameters, Plan)
261261

262+
async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
263+
"""Retrieve all plans for a specific team."""
264+
query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id"
265+
parameters = [
266+
{"name": "@user_id", "value": self.user_id},
267+
{"name": "@team_id", "value": team_id},
268+
{"name": "@data_type", "value": "plan"},
269+
]
270+
return await self.query_items(query, parameters, Plan)
271+
262272
# Step Operations
263273
async def add_step(self, step: Step) -> None:
264274
"""Add a step to CosmosDB."""
@@ -434,23 +444,22 @@ async def update_team(self, team: TeamConfiguration) -> None:
434444
"""
435445
await self.update_item(team)
436446

437-
async def get_current_team(self, user_id: str, team_id: str) -> UserCurrentTeam:
447+
async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
438448
"""Retrieve the current team for a user."""
439449
await self._ensure_initialized()
440450
if self.container is None:
441451
return None
442452

443-
query = "SELECT * FROM c WHERE c.user_id=@user_id AND c.is_default=true"
453+
query = "SELECT * FROM c WHERE c.data_type=@data_type AND c.user_id=@user_id"
444454
parameters = [
455+
{"name": "@data_type", "value": "user_current_team"},
445456
{"name": "@user_id", "value": user_id},
446-
{"name": "@team_id", "value": team_id},
447457
]
448458

449-
items = self.container.query_items(query=query, parameters=parameters)
450-
async for item in items:
451-
return UserCurrentTeam(**item)
459+
# Get the appropriate model class
460+
teams = await self.query_items(query, parameters, UserCurrentTeam)
461+
return teams[0] if teams else None
452462

453-
return None
454463
async def set_current_team(self, current_team: UserCurrentTeam) -> None:
455464
"""Set the current team for a user."""
456465
await self._ensure_initialized()

src/backend/common/database/database_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ async def get_plan(self, plan_id: str) -> Optional[Plan]:
106106
async def get_all_plans(self) -> List[Plan]:
107107
"""Retrieve all plans for the user."""
108108
pass
109-
109+
@abstractmethod
110+
async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]:
111+
"""Retrieve all plans for a specific team."""
112+
pass
110113
@abstractmethod
111114
async def get_data_by_type_and_session_id(
112115
self, data_type: str, session_id: str
@@ -192,7 +195,7 @@ async def get_steps_for_plan(self, plan_id: str) -> List[Step]:
192195
pass
193196

194197
@abstractmethod
195-
async def get_current_team(self, user_id: str, team_id: str) -> UserCurrentTeam:
198+
async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]:
196199
"""Retrieve the current team for a user."""
197200
pass
198201

src/backend/v3/api/router.py

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
TeamSelectionRequest, UserCurrentTeam)
1414
from common.utils.event_utils import track_event_if_configured
1515
from common.utils.utils_kernel import rai_success, rai_validate_team_config
16-
from fastapi import (APIRouter, BackgroundTasks, Depends, FastAPI, File,
16+
from fastapi import (APIRouter, BackgroundTasks, Depends, FastAPI, File, Query,
1717
HTTPException, Request, UploadFile, WebSocket,
1818
WebSocketDisconnect)
1919
from kernel_agents.agent_factory import AgentFactory
@@ -103,7 +103,11 @@ async def init_team(
103103
# Initialize memory store and service
104104
memory_store = await DatabaseFactory.get_database(user_id=user_id)
105105
team_service = TeamService(memory_store)
106-
106+
user_current_team = await memory_store.get_current_team(user_id=user_id)
107+
if not user_current_team:
108+
await team_service.handle_team_selection(user_id=user_id, team_id=init_team_id)
109+
else:
110+
init_team_id = user_current_team.team_id
107111
# Verify the team exists and user has access to it
108112
team_configuration = await team_service.get_team_configuration(init_team_id, user_id)
109113
if team_configuration is None:
@@ -120,7 +124,8 @@ async def init_team(
120124

121125
return {
122126
"status": "Request started successfully",
123-
"team_id": init_team_id
127+
"team_id": init_team_id,
128+
"team": team_configuration
124129
}
125130

126131
except Exception as e:
@@ -184,30 +189,30 @@ async def process_request(background_tasks: BackgroundTasks, input_task: InputTa
184189
"""
185190

186191

187-
# if not await rai_success(input_task.description, False):
188-
# track_event_if_configured(
189-
# "RAI failed",
190-
# {
191-
# "status": "Plan not created - RAI check failed",
192-
# "description": input_task.description,
193-
# "session_id": input_task.session_id,
194-
# },
195-
# )
196-
# raise HTTPException(
197-
# status_code=400,
198-
# detail={
199-
# "error_type": "RAI_VALIDATION_FAILED",
200-
# "message": "Content Safety Check Failed",
201-
# "description": "Your request contains content that doesn't meet our safety guidelines. Please modify your request to ensure it's appropriate and try again.",
202-
# "suggestions": [
203-
# "Remove any potentially harmful, inappropriate, or unsafe content",
204-
# "Use more professional and constructive language",
205-
# "Focus on legitimate business or educational objectives",
206-
# "Ensure your request complies with content policies",
207-
# ],
208-
# "user_action": "Please revise your request and try again",
209-
# },
210-
# )
192+
if not await rai_success(input_task.description, False):
193+
track_event_if_configured(
194+
"RAI failed",
195+
{
196+
"status": "Plan not created - RAI check failed",
197+
"description": input_task.description,
198+
"session_id": input_task.session_id,
199+
},
200+
)
201+
raise HTTPException(
202+
status_code=400,
203+
detail={
204+
"error_type": "RAI_VALIDATION_FAILED",
205+
"message": "Content Safety Check Failed",
206+
"description": "Your request contains content that doesn't meet our safety guidelines. Please modify your request to ensure it's appropriate and try again.",
207+
"suggestions": [
208+
"Remove any potentially harmful, inappropriate, or unsafe content",
209+
"Use more professional and constructive language",
210+
"Focus on legitimate business or educational objectives",
211+
"Ensure your request complies with content policies",
212+
],
213+
"user_action": "Please revise your request and try again",
214+
},
215+
)
211216

212217
authenticated_user = get_authenticated_user_details(request_headers=request.headers)
213218
user_id = authenticated_user["user_principal_id"]
@@ -309,10 +314,14 @@ async def plan_approval(human_feedback: messages.PlanApprovalResponse, request:
309314
if user_id and human_feedback.plan_dot_id:
310315
if orchestration_config and human_feedback.plan_dot_id in orchestration_config.approvals:
311316
orchestration_config.approvals[human_feedback.plan_dot_id] = human_feedback.approved
317+
orchestration_config.plans[human_feedback.plan_dot_id]["plan_id"] = human_feedback.plan_id
318+
print("Plan approval received:", human_feedback)
319+
print("Updated orchestration config:", orchestration_config.plans[human_feedback.plan_dot_id])
312320
track_event_if_configured(
313321
"PlanApprovalReceived",
314322
{
315-
"plan_id": human_feedback.plan_dot_id,
323+
"plan_id": human_feedback.plan_id,
324+
"plan_dot_id": human_feedback.plan_dot_id,
316325
"approved": human_feedback.approved,
317326
"user_id": user_id,
318327
"feedback": human_feedback.feedback
@@ -351,7 +360,7 @@ async def user_clarification(human_feedback: messages.UserClarificationResponse,
351360

352361

353362
@app_v3.post("/upload_team_config")
354-
async def upload_team_config_endpoint(request: Request, file: UploadFile = File(...)):
363+
async def upload_team_config(request: Request, file: UploadFile = File(...), team_id: Optional[str] = Query(None),):
355364
"""
356365
Upload and save a team configuration JSON file.
357366
@@ -487,6 +496,10 @@ async def upload_team_config_endpoint(request: Request, file: UploadFile = File(
487496

488497
# Save the configuration
489498
try:
499+
print("Saving team configuration...", team_id)
500+
if team_id:
501+
team_config.team_id = team_id
502+
team_config.id = team_id # Ensure id is also set for updates
490503
team_id = await team_service.save_team_configuration(team_config)
491504
except ValueError as e:
492505
raise HTTPException(
@@ -521,7 +534,7 @@ async def upload_team_config_endpoint(request: Request, file: UploadFile = File(
521534

522535

523536
@app_v3.get("/team_configs")
524-
async def get_team_configs_endpoint(request: Request):
537+
async def get_team_configs(request: Request):
525538
"""
526539
Retrieve all team configurations for the current user.
527540
@@ -594,7 +607,7 @@ async def get_team_configs_endpoint(request: Request):
594607

595608

596609
@app_v3.get("/team_configs/{team_id}")
597-
async def get_team_config_by_id_endpoint(team_id: str, request: Request):
610+
async def get_team_config_by_id(team_id: str, request: Request):
598611
"""
599612
Retrieve a specific team configuration by ID.
600613
@@ -676,7 +689,7 @@ async def get_team_config_by_id_endpoint(team_id: str, request: Request):
676689

677690

678691
@app_v3.delete("/team_configs/{team_id}")
679-
async def delete_team_config_endpoint(team_id: str, request: Request):
692+
async def delete_team_config(team_id: str, request: Request):
680693
"""
681694
Delete a team configuration by ID.
682695
@@ -754,7 +767,7 @@ async def delete_team_config_endpoint(team_id: str, request: Request):
754767

755768

756769
@app_v3.get("/model_deployments")
757-
async def get_model_deployments_endpoint(request: Request):
770+
async def get_model_deployments(request: Request):
758771
"""
759772
Get information about available model deployments for debugging/validation.
760773
@@ -786,7 +799,7 @@ async def get_model_deployments_endpoint(request: Request):
786799

787800

788801
@app_v3.post("/select_team")
789-
async def select_team_endpoint(selection: TeamSelectionRequest, request: Request):
802+
async def select_team(selection: TeamSelectionRequest, request: Request):
790803
"""
791804
Select the current team for the user session.
792805
"""
@@ -870,7 +883,7 @@ async def select_team_endpoint(selection: TeamSelectionRequest, request: Request
870883

871884

872885
@app_v3.get("/search_indexes")
873-
async def get_search_indexes_endpoint(request: Request):
886+
async def get_search_indexes(request: Request):
874887
"""
875888
Get information about available search indexes for debugging/validation.
876889

src/backend/v3/common/services/team_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ async def handle_team_selection(self, user_id: str, team_id: str) -> bool:
255255
True if successful, False otherwise
256256
"""
257257
try:
258-
current_team = await self.memory_context.get_current_team(user_id, team_id)
258+
current_team = await self.memory_context.get_current_team(user_id)
259259

260260
if current_team is None:
261261
current_team = UserCurrentTeam(user_id=user_id, team_id=team_id)

src/backend/v3/models/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class PlanApprovalResponse:
5252
plan_dot_id: str
5353
approved: bool
5454
feedback: str | None = None
55+
plan_id: str | None = None
5556

5657
@dataclass(slots=True)
5758
class ReplanApprovalRequest:

0 commit comments

Comments
 (0)