Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions src/backend/v3/api/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import contextvars
import json
import logging
import uuid
Expand Down Expand Up @@ -31,7 +30,6 @@
from v3.common.services.team_service import TeamService
from v3.config.settings import (
connection_config,
current_user_id,
orchestration_config,
team_config,
)
Expand All @@ -57,8 +55,6 @@ async def start_comms(

user_id = user_id or "00000000-0000-0000-0000-000000000000"

current_user_id.set(user_id)

# Add to the connection manager for backend updates
connection_config.add_connection(
process_id=process_id, connection=websocket, user_id=user_id
Expand Down Expand Up @@ -90,7 +86,7 @@ async def start_comms(
logging.error(f"Error in WebSocket connection: {str(e)}")
finally:
# Always clean up the connection
await connection_config.close_connection(user_id)
await connection_config.close_connection(process_id=process_id)


@app_v3.get("/init_team")
Expand Down Expand Up @@ -304,18 +300,14 @@ async def process_request(
raise HTTPException(status_code=500, detail="Failed to create plan")

try:
current_user_id.set(user_id) # Set context
current_context = contextvars.copy_context() # Capture context
# background_tasks.add_task(
# lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task)
# )

async def run_with_context():
return await current_context.run(
OrchestrationManager().run_orchestration, user_id, input_task
)
async def run_orchestration_task():
await OrchestrationManager().run_orchestration(user_id, input_task)

background_tasks.add_task(run_with_context)
background_tasks.add_task(run_orchestration_task)

return {
"status": "Request started successfully",
Expand Down
13 changes: 2 additions & 11 deletions src/backend/v3/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
"""

import asyncio
import contextvars
import json
import logging
from typing import Dict, Optional
from typing import Dict

from common.config.app_config import config
from common.models.messages_kernel import TeamConfiguration
Expand All @@ -21,11 +20,6 @@

logger = logging.getLogger(__name__)

# Create a context variable to track current user
current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"current_user_id", default=None
)


class AzureConfig:
"""Azure OpenAI and authentication configuration."""
Expand Down Expand Up @@ -181,13 +175,10 @@ async def close_connection(self, process_id):
async def send_status_update_async(
self,
message: any,
user_id: Optional[str] = None,
user_id: str,
message_type: WebsocketMessageType = WebsocketMessageType.SYSTEM_MESSAGE,
):
"""Send a status update to a specific client."""
# If no process_id provided, get from context
if user_id is None:
user_id = current_user_id.get()

if not user_id:
logger.warning("No user_id available for WebSocket message")
Expand Down
16 changes: 6 additions & 10 deletions src/backend/v3/magentic_agents/magentic_agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from common.config.app_config import config
from common.models.messages_kernel import TeamConfiguration
from v3.config.settings import current_user_id
from v3.magentic_agents.foundry_agent import FoundryAgentTemplate
from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig

Expand Down Expand Up @@ -40,13 +39,12 @@ def __init__(self):
# data = json.load(f)
# return json.loads(json.dumps(data), object_hook=lambda d: SimpleNamespace(**d))

async def create_agent_from_config(
self, agent_obj: SimpleNamespace
) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]:
async def create_agent_from_config(self, user_id: str, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]:
"""
Create an agent from configuration object.

Args:
user_id: User ID
agent_obj: Agent object from parsed JSON (SimpleNamespace)
team_model: Model name to determine which template to use

Expand All @@ -62,7 +60,6 @@ async def create_agent_from_config(

if not deployment_name and agent_obj.name.lower() == "proxyagent":
self.logger.info("Creating ProxyAgent")
user_id = current_user_id.get()
return ProxyAgent(user_id=user_id)

# Validate supported models
Expand Down Expand Up @@ -133,11 +130,12 @@ async def create_agent_from_config(
)
return agent

async def get_agents(self, team_config_input: TeamConfiguration) -> List:
async def get_agents(self, user_id: str, team_config_input: TeamConfiguration) -> List:
"""
Create and return a team of agents from JSON configuration.

Args:
user_id: User ID
team_config_input: team configuration object from cosmos db

Returns:
Expand All @@ -151,11 +149,9 @@ async def get_agents(self, team_config_input: TeamConfiguration) -> List:

for i, agent_cfg in enumerate(team_config_input.agents, 1):
try:
self.logger.info(
f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}"
)
self.logger.info(f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}")

agent = await self.create_agent_from_config(agent_cfg)
agent = await self.create_agent_from_config(user_id, agent_cfg)
initalized_agents.append(agent)
self._agent_list.append(agent) # Keep track for cleanup

Expand Down
31 changes: 13 additions & 18 deletions src/backend/v3/magentic_agents/proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,11 @@
)
from semantic_kernel.exceptions.agent_exceptions import AgentThreadOperationException
from typing_extensions import override
from v3.callbacks.response_handlers import (
agent_response_callback,
streaming_agent_response_callback,
)
from v3.config.settings import connection_config, current_user_id, orchestration_config
from v3.models.messages import (
UserClarificationRequest,
UserClarificationResponse,
WebsocketMessageType,
)
from v3.callbacks.response_handlers import (agent_response_callback,
streaming_agent_response_callback)
from v3.config.settings import connection_config, orchestration_config
from v3.models.messages import (UserClarificationRequest,
UserClarificationResponse, WebsocketMessageType)


class DummyAgentThread(AgentThread):
Expand Down Expand Up @@ -110,13 +105,13 @@ class ProxyAgent(Agent):
"""Simple proxy agent that prompts for human clarification."""

# Declare as Pydantic field
user_id: Optional[str] = Field(
user_id: str = Field(
default=None, description="User ID for WebSocket messaging"
)

def __init__(self, user_id: str = None, **kwargs):
# Get user_id from parameter or context, fallback to empty string
effective_user_id = user_id or current_user_id.get() or ""
def __init__(self, user_id: str, **kwargs):
# Get user_id from parameter, fallback to empty string
effective_user_id = user_id or ""
super().__init__(
name="ProxyAgent",
description="Call this agent when you need to clarify requests by asking the human user for more information. Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.",
Expand All @@ -139,15 +134,15 @@ def _create_message_content(
async def _trigger_response_callbacks(self, message_content: ChatMessageContent):
"""Manually trigger the same response callbacks used by other agents."""
# Get current user_id dynamically instead of using stored value
current_user = current_user_id.get() or self.user_id or ""
current_user = self.user_id or ""

# Trigger the standard agent response callback
agent_response_callback(message_content, current_user)

async def _trigger_streaming_callbacks(self, content: str, is_final: bool = False):
"""Manually trigger streaming callbacks for real-time updates."""
# Get current user_id dynamically instead of using stored value
current_user = current_user_id.get() or self.user_id or ""
current_user = self.user_id or ""
streaming_message = StreamingChatMessageContent(
role=AuthorRole.ASSISTANT, content=content, name=self.name, choice_index=0
)
Expand Down Expand Up @@ -181,7 +176,7 @@ async def invoke(
"type": WebsocketMessageType.USER_CLARIFICATION_REQUEST,
"data": clarification_message,
},
user_id=current_user_id.get(),
user_id=self.user_id,
message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST,
)

Expand Down Expand Up @@ -238,7 +233,7 @@ async def invoke_stream(
"type": WebsocketMessageType.USER_CLARIFICATION_REQUEST,
"data": clarification_message,
},
user_id=current_user_id.get(),
user_id=self.user_id,
message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST,
)

Expand Down
27 changes: 19 additions & 8 deletions src/backend/v3/orchestration/human_approval_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
)
from semantic_kernel.contents import ChatMessageContent
from v3.config.settings import connection_config, current_user_id, orchestration_config
from v3.config.settings import connection_config, orchestration_config
from v3.models.models import MPlan
from v3.orchestration.helper.plan_to_mplan_converter import PlanToMPlanConverter
from v3.orchestration.helper.plan_to_mplan_converter import \
PlanToMPlanConverter

# Using a module level logger to avoid pydantic issues around inherited fields
logger = logging.getLogger(__name__)
Expand All @@ -38,9 +39,17 @@ class HumanApprovalMagenticManager(StandardMagenticManager):
# Define Pydantic fields to avoid validation errors
approval_enabled: bool = True
magentic_plan: Optional[MPlan] = None
current_user_id: Optional[str] = None
current_user_id: str

def __init__(self, user_id: str, *args, **kwargs):
"""
Initialize the HumanApprovalMagenticManager.
Args:
user_id: ID of the user to associate with this orchestration instance.
*args: Additional positional arguments for the parent StandardMagenticManager.
**kwargs: Additional keyword arguments for the parent StandardMagenticManager.
"""

def __init__(self, *args, **kwargs):
# Remove any custom kwargs before passing to parent

plan_append = """
Expand Down Expand Up @@ -76,6 +85,8 @@ def __init__(self, *args, **kwargs):
)
kwargs["final_answer_prompt"] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append

kwargs['current_user_id'] = user_id

super().__init__(*args, **kwargs)

async def plan(self, magentic_context: MagenticContext) -> Any:
Expand All @@ -100,7 +111,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:

self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger)

self.magentic_plan.user_id = current_user_id.get()
self.magentic_plan.user_id = self.current_user_id

# Request approval from the user before executing the plan
approval_message = messages.PlanApprovalRequest(
Expand All @@ -124,7 +135,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
# The user_id will be automatically retrieved from context
await connection_config.send_status_update_async(
message=approval_message,
user_id=current_user_id.get(),
user_id=self.current_user_id,
message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST,
)

Expand All @@ -141,7 +152,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
"type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE,
"data": approval_response,
},
user_id=current_user_id.get(),
user_id=self.current_user_id,
message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE,
)
raise Exception("Plan execution cancelled by user")
Expand Down Expand Up @@ -170,7 +181,7 @@ async def create_progress_ledger(

await connection_config.send_status_update_async(
message=final_message,
user_id=current_user_id.get(),
user_id=self.current_user_id,
message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE,
)

Expand Down
30 changes: 8 additions & 22 deletions src/backend/v3/orchestration/orchestration_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
"""Orchestration manager to handle the orchestration logic."""
import asyncio
import contextvars
import logging
import uuid
from contextvars import ContextVar
from typing import List, Optional

from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential
Expand All @@ -15,27 +13,16 @@

# Create custom execution settings to fix schema issues
from semantic_kernel.connectors.ai.open_ai import (
AzureChatCompletion,
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent
from v3.callbacks.response_handlers import (
agent_response_callback,
streaming_agent_response_callback,
)
from v3.config.settings import (
connection_config,
orchestration_config,
)
AzureChatCompletion, OpenAIChatPromptExecutionSettings)
from semantic_kernel.contents import (ChatMessageContent,
StreamingChatMessageContent)
from v3.callbacks.response_handlers import (agent_response_callback,
streaming_agent_response_callback)
from v3.config.settings import connection_config, orchestration_config
from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory
from v3.models.messages import WebsocketMessageType
from v3.orchestration.human_approval_manager import HumanApprovalMagenticManager

# Context variable to hold the current user ID
current_user_id: ContextVar[Optional[str]] = contextvars.ContextVar(
"current_user_id", default=None
)


class OrchestrationManager:
"""Manager for handling orchestration logic."""
Expand Down Expand Up @@ -69,6 +56,7 @@ def get_token():
magentic_orchestration = MagenticOrchestration(
members=agents,
manager=HumanApprovalMagenticManager(
user_id=user_id,
chat_completion_service=AzureChatCompletion(
deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME,
endpoint=config.AZURE_OPENAI_ENDPOINT,
Expand Down Expand Up @@ -122,15 +110,14 @@ async def get_current_or_new_orchestration(
except Exception as e:
cls.logger.error("Error closing agent: %s", e)
factory = MagenticAgentFactory()
agents = await factory.get_agents(team_config_input=team_config)
agents = await factory.get_agents(user_id=user_id, team_config_input=team_config)
orchestration_config.orchestrations[user_id] = await cls.init_orchestration(
agents, user_id
)
return orchestration_config.get_current_orchestration(user_id)

async def run_orchestration(self, user_id, input_task) -> None:
"""Run the orchestration with user input loop."""
token = current_user_id.set(user_id)

job_id = str(uuid.uuid4())
orchestration_config.approvals[job_id] = None
Expand Down Expand Up @@ -190,4 +177,3 @@ async def run_orchestration(self, user_id, input_task) -> None:
self.logger.error(f"Unexpected error: {e}")
finally:
await runtime.stop_when_idle()
current_user_id.reset(token)
Loading