diff --git a/src/backend/v3/api/router.py b/src/backend/v3/api/router.py index 4ed51a4a..11e83c53 100644 --- a/src/backend/v3/api/router.py +++ b/src/backend/v3/api/router.py @@ -1,5 +1,4 @@ import asyncio -import contextvars import json import logging import uuid @@ -31,7 +30,6 @@ from v3.common.services.team_service import TeamService from v3.config.settings import ( connection_config, - current_user_id, orchestration_config, team_config, ) @@ -57,8 +55,6 @@ async def start_comms( user_id = user_id or "00000000-0000-0000-0000-000000000000" - current_user_id.set(user_id) - # Add to the connection manager for backend updates connection_config.add_connection( process_id=process_id, connection=websocket, user_id=user_id @@ -90,7 +86,7 @@ async def start_comms( logging.error(f"Error in WebSocket connection: {str(e)}") finally: # Always clean up the connection - await connection_config.close_connection(user_id) + await connection_config.close_connection(process_id=process_id) @app_v3.get("/init_team") @@ -304,18 +300,14 @@ async def process_request( raise HTTPException(status_code=500, detail="Failed to create plan") try: - current_user_id.set(user_id) # Set context - current_context = contextvars.copy_context() # Capture context # background_tasks.add_task( # lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task) # ) - async def run_with_context(): - return await current_context.run( - OrchestrationManager().run_orchestration, user_id, input_task - ) + async def run_orchestration_task(): + await OrchestrationManager().run_orchestration(user_id, input_task) - background_tasks.add_task(run_with_context) + background_tasks.add_task(run_orchestration_task) return { "status": "Request started successfully", diff --git a/src/backend/v3/config/settings.py b/src/backend/v3/config/settings.py index fbf12e27..1dcbfbc6 100644 --- a/src/backend/v3/config/settings.py +++ b/src/backend/v3/config/settings.py @@ -4,10 +4,9 @@ """ import asyncio -import contextvars import json import logging -from typing import Dict, Optional +from typing import Dict from common.config.app_config import config from common.models.messages_kernel import TeamConfiguration @@ -21,11 +20,6 @@ logger = logging.getLogger(__name__) -# Create a context variable to track current user -current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( - "current_user_id", default=None -) - class AzureConfig: """Azure OpenAI and authentication configuration.""" @@ -181,13 +175,10 @@ async def close_connection(self, process_id): async def send_status_update_async( self, message: any, - user_id: Optional[str] = None, + user_id: str, message_type: WebsocketMessageType = WebsocketMessageType.SYSTEM_MESSAGE, ): """Send a status update to a specific client.""" - # If no process_id provided, get from context - if user_id is None: - user_id = current_user_id.get() if not user_id: logger.warning("No user_id available for WebSocket message") diff --git a/src/backend/v3/magentic_agents/magentic_agent_factory.py b/src/backend/v3/magentic_agents/magentic_agent_factory.py index dd13e27a..ed74e89b 100644 --- a/src/backend/v3/magentic_agents/magentic_agent_factory.py +++ b/src/backend/v3/magentic_agents/magentic_agent_factory.py @@ -8,7 +8,6 @@ from common.config.app_config import config from common.models.messages_kernel import TeamConfiguration -from v3.config.settings import current_user_id from v3.magentic_agents.foundry_agent import FoundryAgentTemplate from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig @@ -40,13 +39,12 @@ def __init__(self): # data = json.load(f) # return json.loads(json.dumps(data), object_hook=lambda d: SimpleNamespace(**d)) - async def create_agent_from_config( - self, agent_obj: SimpleNamespace - ) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]: + async def create_agent_from_config(self, user_id: str, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]: """ Create an agent from configuration object. Args: + user_id: User ID agent_obj: Agent object from parsed JSON (SimpleNamespace) team_model: Model name to determine which template to use @@ -62,7 +60,6 @@ async def create_agent_from_config( if not deployment_name and agent_obj.name.lower() == "proxyagent": self.logger.info("Creating ProxyAgent") - user_id = current_user_id.get() return ProxyAgent(user_id=user_id) # Validate supported models @@ -133,11 +130,12 @@ async def create_agent_from_config( ) return agent - async def get_agents(self, team_config_input: TeamConfiguration) -> List: + async def get_agents(self, user_id: str, team_config_input: TeamConfiguration) -> List: """ Create and return a team of agents from JSON configuration. Args: + user_id: User ID team_config_input: team configuration object from cosmos db Returns: @@ -151,11 +149,9 @@ async def get_agents(self, team_config_input: TeamConfiguration) -> List: for i, agent_cfg in enumerate(team_config_input.agents, 1): try: - self.logger.info( - f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}" - ) + self.logger.info(f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}") - agent = await self.create_agent_from_config(agent_cfg) + agent = await self.create_agent_from_config(user_id, agent_cfg) initalized_agents.append(agent) self._agent_list.append(agent) # Keep track for cleanup diff --git a/src/backend/v3/magentic_agents/proxy_agent.py b/src/backend/v3/magentic_agents/proxy_agent.py index d18f5eb9..db952cc5 100644 --- a/src/backend/v3/magentic_agents/proxy_agent.py +++ b/src/backend/v3/magentic_agents/proxy_agent.py @@ -24,16 +24,11 @@ ) from semantic_kernel.exceptions.agent_exceptions import AgentThreadOperationException from typing_extensions import override -from v3.callbacks.response_handlers import ( - agent_response_callback, - streaming_agent_response_callback, -) -from v3.config.settings import connection_config, current_user_id, orchestration_config -from v3.models.messages import ( - UserClarificationRequest, - UserClarificationResponse, - WebsocketMessageType, -) +from v3.callbacks.response_handlers import (agent_response_callback, + streaming_agent_response_callback) +from v3.config.settings import connection_config, orchestration_config +from v3.models.messages import (UserClarificationRequest, + UserClarificationResponse, WebsocketMessageType) class DummyAgentThread(AgentThread): @@ -110,13 +105,13 @@ class ProxyAgent(Agent): """Simple proxy agent that prompts for human clarification.""" # Declare as Pydantic field - user_id: Optional[str] = Field( + user_id: str = Field( default=None, description="User ID for WebSocket messaging" ) - def __init__(self, user_id: str = None, **kwargs): - # Get user_id from parameter or context, fallback to empty string - effective_user_id = user_id or current_user_id.get() or "" + def __init__(self, user_id: str, **kwargs): + # Get user_id from parameter, fallback to empty string + effective_user_id = user_id or "" super().__init__( name="ProxyAgent", description="Call this agent when you need to clarify requests by asking the human user for more information. Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.", @@ -139,7 +134,7 @@ def _create_message_content( async def _trigger_response_callbacks(self, message_content: ChatMessageContent): """Manually trigger the same response callbacks used by other agents.""" # Get current user_id dynamically instead of using stored value - current_user = current_user_id.get() or self.user_id or "" + current_user = self.user_id or "" # Trigger the standard agent response callback agent_response_callback(message_content, current_user) @@ -147,7 +142,7 @@ async def _trigger_response_callbacks(self, message_content: ChatMessageContent) async def _trigger_streaming_callbacks(self, content: str, is_final: bool = False): """Manually trigger streaming callbacks for real-time updates.""" # Get current user_id dynamically instead of using stored value - current_user = current_user_id.get() or self.user_id or "" + current_user = self.user_id or "" streaming_message = StreamingChatMessageContent( role=AuthorRole.ASSISTANT, content=content, name=self.name, choice_index=0 ) @@ -181,7 +176,7 @@ async def invoke( "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, "data": clarification_message, }, - user_id=current_user_id.get(), + user_id=self.user_id, message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, ) @@ -238,7 +233,7 @@ async def invoke_stream( "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, "data": clarification_message, }, - user_id=current_user_id.get(), + user_id=self.user_id, message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, ) diff --git a/src/backend/v3/orchestration/human_approval_manager.py b/src/backend/v3/orchestration/human_approval_manager.py index 7eef305a..1efd6c44 100644 --- a/src/backend/v3/orchestration/human_approval_manager.py +++ b/src/backend/v3/orchestration/human_approval_manager.py @@ -20,9 +20,10 @@ ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT, ) from semantic_kernel.contents import ChatMessageContent -from v3.config.settings import connection_config, current_user_id, orchestration_config +from v3.config.settings import connection_config, orchestration_config from v3.models.models import MPlan -from v3.orchestration.helper.plan_to_mplan_converter import PlanToMPlanConverter +from v3.orchestration.helper.plan_to_mplan_converter import \ + PlanToMPlanConverter # Using a module level logger to avoid pydantic issues around inherited fields logger = logging.getLogger(__name__) @@ -38,9 +39,17 @@ class HumanApprovalMagenticManager(StandardMagenticManager): # Define Pydantic fields to avoid validation errors approval_enabled: bool = True magentic_plan: Optional[MPlan] = None - current_user_id: Optional[str] = None + current_user_id: str + + def __init__(self, user_id: str, *args, **kwargs): + """ + Initialize the HumanApprovalMagenticManager. + Args: + user_id: ID of the user to associate with this orchestration instance. + *args: Additional positional arguments for the parent StandardMagenticManager. + **kwargs: Additional keyword arguments for the parent StandardMagenticManager. + """ - def __init__(self, *args, **kwargs): # Remove any custom kwargs before passing to parent plan_append = """ @@ -76,6 +85,8 @@ def __init__(self, *args, **kwargs): ) kwargs["final_answer_prompt"] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append + kwargs['current_user_id'] = user_id + super().__init__(*args, **kwargs) async def plan(self, magentic_context: MagenticContext) -> Any: @@ -100,7 +111,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any: self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger) - self.magentic_plan.user_id = current_user_id.get() + self.magentic_plan.user_id = self.current_user_id # Request approval from the user before executing the plan approval_message = messages.PlanApprovalRequest( @@ -124,7 +135,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any: # The user_id will be automatically retrieved from context await connection_config.send_status_update_async( message=approval_message, - user_id=current_user_id.get(), + user_id=self.current_user_id, message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST, ) @@ -141,7 +152,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any: "type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, "data": approval_response, }, - user_id=current_user_id.get(), + user_id=self.current_user_id, message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, ) raise Exception("Plan execution cancelled by user") @@ -170,7 +181,7 @@ async def create_progress_ledger( await connection_config.send_status_update_async( message=final_message, - user_id=current_user_id.get(), + user_id=self.current_user_id, message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE, ) diff --git a/src/backend/v3/orchestration/orchestration_manager.py b/src/backend/v3/orchestration/orchestration_manager.py index c62452e7..e3412667 100644 --- a/src/backend/v3/orchestration/orchestration_manager.py +++ b/src/backend/v3/orchestration/orchestration_manager.py @@ -1,10 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. """Orchestration manager to handle the orchestration logic.""" import asyncio -import contextvars import logging import uuid -from contextvars import ContextVar from typing import List, Optional from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential @@ -15,27 +13,16 @@ # Create custom execution settings to fix schema issues from semantic_kernel.connectors.ai.open_ai import ( - AzureChatCompletion, - OpenAIChatPromptExecutionSettings, -) -from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent -from v3.callbacks.response_handlers import ( - agent_response_callback, - streaming_agent_response_callback, -) -from v3.config.settings import ( - connection_config, - orchestration_config, -) + AzureChatCompletion, OpenAIChatPromptExecutionSettings) +from semantic_kernel.contents import (ChatMessageContent, + StreamingChatMessageContent) +from v3.callbacks.response_handlers import (agent_response_callback, + streaming_agent_response_callback) +from v3.config.settings import connection_config, orchestration_config from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory from v3.models.messages import WebsocketMessageType from v3.orchestration.human_approval_manager import HumanApprovalMagenticManager -# Context variable to hold the current user ID -current_user_id: ContextVar[Optional[str]] = contextvars.ContextVar( - "current_user_id", default=None -) - class OrchestrationManager: """Manager for handling orchestration logic.""" @@ -69,6 +56,7 @@ def get_token(): magentic_orchestration = MagenticOrchestration( members=agents, manager=HumanApprovalMagenticManager( + user_id=user_id, chat_completion_service=AzureChatCompletion( deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME, endpoint=config.AZURE_OPENAI_ENDPOINT, @@ -122,7 +110,7 @@ async def get_current_or_new_orchestration( except Exception as e: cls.logger.error("Error closing agent: %s", e) factory = MagenticAgentFactory() - agents = await factory.get_agents(team_config_input=team_config) + agents = await factory.get_agents(user_id=user_id, team_config_input=team_config) orchestration_config.orchestrations[user_id] = await cls.init_orchestration( agents, user_id ) @@ -130,7 +118,6 @@ async def get_current_or_new_orchestration( async def run_orchestration(self, user_id, input_task) -> None: """Run the orchestration with user input loop.""" - token = current_user_id.set(user_id) job_id = str(uuid.uuid4()) orchestration_config.approvals[job_id] = None @@ -190,4 +177,3 @@ async def run_orchestration(self, user_id, input_task) -> None: self.logger.error(f"Unexpected error: {e}") finally: await runtime.stop_when_idle() - current_user_id.reset(token)