Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions src/backend/v3/api/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import contextvars
import json
import logging
import uuid
Expand All @@ -19,8 +18,8 @@
from semantic_kernel.agents.runtime import InProcessRuntime
from v3.common.services.plan_service import PlanService
from v3.common.services.team_service import TeamService
from v3.config.settings import (connection_config, current_user_id,
orchestration_config, team_config)
from v3.config.settings import (connection_config, orchestration_config,
team_config)
from v3.orchestration.orchestration_manager import OrchestrationManager

router = APIRouter()
Expand All @@ -41,8 +40,6 @@ async def start_comms(websocket: WebSocket, process_id: str, user_id: str = Quer

user_id = user_id or "00000000-0000-0000-0000-000000000000"

current_user_id.set(user_id)

# Add to the connection manager for backend updates
connection_config.add_connection(
process_id=process_id, connection=websocket, user_id=user_id
Expand Down Expand Up @@ -74,7 +71,7 @@ async def start_comms(websocket: WebSocket, process_id: str, user_id: str = Quer
logging.error(f"Error in WebSocket connection: {str(e)}")
finally:
# Always clean up the connection
await connection_config.close_connection(user_id)
await connection_config.close_connection(process_id=process_id)


@app_v3.get("/init_team")
Expand Down Expand Up @@ -288,18 +285,14 @@ async def process_request(
raise HTTPException(status_code=500, detail="Failed to create plan")

try:
current_user_id.set(user_id) # Set context
current_context = contextvars.copy_context() # Capture context
# background_tasks.add_task(
# lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task)
# )

async def run_with_context():
return await current_context.run(
OrchestrationManager().run_orchestration, user_id, input_task
)
async def run_orchestration_task():
await OrchestrationManager().run_orchestration(user_id, input_task)

background_tasks.add_task(run_with_context)
background_tasks.add_task(run_orchestration_task)

return {
"status": "Request started successfully",
Expand Down
2 changes: 1 addition & 1 deletion src/backend/v3/callbacks/response_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from semantic_kernel.contents import (ChatMessageContent,
StreamingChatMessageContent)
from v3.config.settings import connection_config, current_user_id
from v3.config.settings import connection_config
from v3.models.messages import (AgentMessage, AgentMessageStreaming,
AgentToolCall, AgentToolMessage, WebsocketMessageType)

Expand Down
11 changes: 1 addition & 10 deletions src/backend/v3/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

import asyncio
import contextvars
import json
import logging
from typing import Dict, Optional
Expand All @@ -19,11 +18,6 @@

logger = logging.getLogger(__name__)

# Create a context variable to track current user
current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"current_user_id", default=None
)


class AzureConfig:
"""Azure OpenAI and authentication configuration."""
Expand Down Expand Up @@ -177,13 +171,10 @@ async def close_connection(self, process_id):
async def send_status_update_async(
self,
message: any,
user_id: Optional[str] = None,
user_id: str,
message_type: WebsocketMessageType = WebsocketMessageType.SYSTEM_MESSAGE,
):
"""Send a status update to a specific client."""
# If no process_id provided, get from context
if user_id is None:
user_id = current_user_id.get()

if not user_id:
logger.warning("No user_id available for WebSocket message")
Expand Down
14 changes: 7 additions & 7 deletions src/backend/v3/magentic_agents/magentic_agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from common.config.app_config import config
from common.models.messages_kernel import TeamConfiguration
from v3.config.settings import current_user_id
from v3.magentic_agents.foundry_agent import FoundryAgentTemplate
from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig
# from v3.magentic_agents.models.agent_models import (BingConfig, MCPConfig,
Expand Down Expand Up @@ -42,12 +41,13 @@ def __init__(self):
# with open(file_path, 'r') as f:
# data = json.load(f)
# return json.loads(json.dumps(data), object_hook=lambda d: SimpleNamespace(**d))
async def create_agent_from_config(self, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]:

async def create_agent_from_config(self, user_id: str, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]:
"""
Create an agent from configuration object.

Args:
user_id: User ID
agent_obj: Agent object from parsed JSON (SimpleNamespace)
team_model: Model name to determine which template to use

Expand All @@ -63,7 +63,6 @@ async def create_agent_from_config(self, agent_obj: SimpleNamespace) -> Union[Fo

if not deployment_name and agent_obj.name.lower() == "proxyagent":
self.logger.info("Creating ProxyAgent")
user_id = current_user_id.get()
return ProxyAgent(user_id=user_id)

# Validate supported models
Expand Down Expand Up @@ -124,11 +123,12 @@ async def create_agent_from_config(self, agent_obj: SimpleNamespace) -> Union[Fo
self.logger.info(f"Successfully created and initialized agent '{agent_obj.name}'")
return agent

async def get_agents(self, team_config_input: TeamConfiguration) -> List:
async def get_agents(self, user_id: str, team_config_input: TeamConfiguration) -> List:
"""
Create and return a team of agents from JSON configuration.

Args:
user_id: User ID
team_config_input: team configuration object from cosmos db

Returns:
Expand All @@ -143,8 +143,8 @@ async def get_agents(self, team_config_input: TeamConfiguration) -> List:
for i, agent_cfg in enumerate(team_config_input.agents, 1):
try:
self.logger.info(f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}")
agent = await self.create_agent_from_config(agent_cfg)

agent = await self.create_agent_from_config(user_id, agent_cfg)
initalized_agents.append(agent)
self._agent_list.append(agent) # Keep track for cleanup

Expand Down
19 changes: 9 additions & 10 deletions src/backend/v3/magentic_agents/proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from typing_extensions import override
from v3.callbacks.response_handlers import (agent_response_callback,
streaming_agent_response_callback)
from v3.config.settings import (connection_config, current_user_id,
orchestration_config)
from v3.config.settings import connection_config, orchestration_config
from v3.models.messages import (UserClarificationRequest,
UserClarificationResponse, WebsocketMessageType)

Expand Down Expand Up @@ -94,11 +93,11 @@ class ProxyAgent(Agent):
"""Simple proxy agent that prompts for human clarification."""

# Declare as Pydantic field
user_id: Optional[str] = Field(default=None, description="User ID for WebSocket messaging")
user_id: str = Field(default=None, description="User ID for WebSocket messaging")

def __init__(self, user_id: str = None, **kwargs):
# Get user_id from parameter or context, fallback to empty string
effective_user_id = user_id or current_user_id.get() or ""
def __init__(self, user_id: str, **kwargs):
# Get user_id from parameter, fallback to empty string
effective_user_id = user_id or ""
super().__init__(
name="ProxyAgent",
description="Call this agent when you need to clarify requests by asking the human user for more information. Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.",
Expand All @@ -119,15 +118,15 @@ def _create_message_content(self, content: str, thread_id: str = None) -> ChatMe
async def _trigger_response_callbacks(self, message_content: ChatMessageContent):
"""Manually trigger the same response callbacks used by other agents."""
# Get current user_id dynamically instead of using stored value
current_user = current_user_id.get() or self.user_id or ""
current_user = self.user_id or ""

# Trigger the standard agent response callback
agent_response_callback(message_content, current_user)

async def _trigger_streaming_callbacks(self, content: str, is_final: bool = False):
"""Manually trigger streaming callbacks for real-time updates."""
# Get current user_id dynamically instead of using stored value
current_user = current_user_id.get() or self.user_id or ""
current_user = self.user_id or ""
streaming_message = StreamingChatMessageContent(
role=AuthorRole.ASSISTANT,
content=content,
Expand Down Expand Up @@ -158,7 +157,7 @@ async def invoke(self, message: str,*, thread: AgentThread | None = None,**kwarg
await connection_config.send_status_update_async({
"type": WebsocketMessageType.USER_CLARIFICATION_REQUEST,
"data": clarification_message
}, user_id=current_user_id.get(), message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST)
}, user_id=self.user_id, message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST)

# Get human input
human_response = await self._wait_for_user_clarification(clarification_message.request_id)
Expand Down Expand Up @@ -206,7 +205,7 @@ async def invoke_stream(self, messages, thread=None, **kwargs) -> AsyncIterator[
await connection_config.send_status_update_async({
"type": WebsocketMessageType.USER_CLARIFICATION_REQUEST,
"data": clarification_message
}, user_id=current_user_id.get(), message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST)
}, user_id=self.user_id, message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST)

# Get human input - replace with websocket call when available
human_response = await self._wait_for_user_clarification(clarification_message.request_id)
Expand Down
25 changes: 17 additions & 8 deletions src/backend/v3/orchestration/human_approval_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
ORCHESTRATOR_FINAL_ANSWER_PROMPT, ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT,
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT)
from semantic_kernel.contents import ChatMessageContent
from v3.config.settings import (connection_config, current_user_id,
orchestration_config)
from v3.config.settings import connection_config, orchestration_config
from v3.models.models import MPlan, MStep
from v3.orchestration.helper.plan_to_mplan_converter import \
PlanToMPlanConverter
Expand All @@ -35,9 +34,17 @@ class HumanApprovalMagenticManager(StandardMagenticManager):
# Define Pydantic fields to avoid validation errors
approval_enabled: bool = True
magentic_plan: Optional[MPlan] = None
current_user_id: Optional[str] = None
current_user_id: str

def __init__(self, user_id: str, *args, **kwargs):
"""
Initialize the HumanApprovalMagenticManager.
Args:
user_id: ID of the user to associate with this orchestration instance.
*args: Additional positional arguments for the parent StandardMagenticManager.
**kwargs: Additional keyword arguments for the parent StandardMagenticManager.
"""

def __init__(self, *args, **kwargs):
# Remove any custom kwargs before passing to parent

plan_append = """
Expand Down Expand Up @@ -70,6 +77,8 @@ def __init__(self, *args, **kwargs):
kwargs['task_ledger_plan_update_prompt'] = ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT + plan_append
kwargs['final_answer_prompt'] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append

kwargs['current_user_id'] = user_id

super().__init__(*args, **kwargs)

async def plan(self, magentic_context: MagenticContext) -> Any:
Expand All @@ -94,7 +103,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:

self.magentic_plan = self.plan_to_obj( magentic_context, self.task_ledger)

self.magentic_plan.user_id = current_user_id.get()
self.magentic_plan.user_id = self.current_user_id

# Request approval from the user before executing the plan
approval_message = messages.PlanApprovalRequest(
Expand All @@ -115,7 +124,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
# The user_id will be automatically retrieved from context
await connection_config.send_status_update_async(
message=approval_message,
user_id=current_user_id.get(),
user_id=self.current_user_id,
message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST)

# Wait for user approval
Expand All @@ -129,7 +138,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
await connection_config.send_status_update_async({
"type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE,
"data": approval_response
}, user_id=current_user_id.get(), message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE)
}, user_id=self.current_user_id, message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE)
raise Exception("Plan execution cancelled by user")

async def replan(self,magentic_context: MagenticContext) -> Any:
Expand All @@ -154,7 +163,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> Pro

await connection_config.send_status_update_async(
message= final_message,
user_id=current_user_id.get(),
user_id=self.current_user_id,
message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE)

return ProgressLedger(
Expand Down
13 changes: 3 additions & 10 deletions src/backend/v3/orchestration/orchestration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
""" Orchestration manager to handle the orchestration logic. """

import asyncio
import contextvars
import logging
import os
import uuid
from contextvars import ContextVar
from typing import List, Optional

from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential
Expand All @@ -21,16 +19,12 @@
StreamingChatMessageContent)
from v3.callbacks.response_handlers import (agent_response_callback,
streaming_agent_response_callback)
from v3.config.settings import (config, connection_config, current_user_id,
orchestration_config)
from v3.config.settings import config, connection_config, orchestration_config
from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory
from v3.models.messages import WebsocketMessageType
from v3.orchestration.human_approval_manager import \
HumanApprovalMagenticManager

# Context variable to hold the current user ID
current_user_id: ContextVar[Optional[str]] = contextvars.ContextVar("current_user_id", default=None)

class OrchestrationManager:
"""Manager for handling orchestration logic."""

Expand Down Expand Up @@ -62,6 +56,7 @@ def get_token():
magentic_orchestration = MagenticOrchestration(
members=agents,
manager=HumanApprovalMagenticManager(
user_id=user_id,
chat_completion_service=AzureChatCompletion(
deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME,
endpoint=config.AZURE_OPENAI_ENDPOINT,
Expand Down Expand Up @@ -101,13 +96,12 @@ async def get_current_or_new_orchestration(cls, user_id: str, team_config: TeamC
except Exception as e:
cls.logger.error("Error closing agent: %s", e)
factory = MagenticAgentFactory()
agents = await factory.get_agents(team_config_input=team_config)
agents = await factory.get_agents(user_id=user_id, team_config_input=team_config)
orchestration_config.orchestrations[user_id] = await cls.init_orchestration(agents, user_id)
return orchestration_config.get_current_orchestration(user_id)

async def run_orchestration(self, user_id, input_task) -> None:
""" Run the orchestration with user input loop."""
token = current_user_id.set(user_id)

job_id = str(uuid.uuid4())
orchestration_config.approvals[job_id] = None
Expand Down Expand Up @@ -161,5 +155,4 @@ async def run_orchestration(self, user_id, input_task) -> None:
self.logger.error(f"Unexpected error: {e}")
finally:
await runtime.stop_when_idle()
current_user_id.reset(token)

Loading