diff --git a/src/backend/app_kernel.py b/src/backend/app_kernel.py index af13066d3..8a1afebe8 100644 --- a/src/backend/app_kernel.py +++ b/src/backend/app_kernel.py @@ -1,23 +1,21 @@ # app_kernel.py -import asyncio import logging -import os -# Azure monitoring -import re -import uuid -from typing import Dict, List, Optional from azure.monitor.opentelemetry import configure_azure_monitor from common.config.app_config import config from common.models.messages_kernel import UserLanguage + # FastAPI imports -from fastapi import FastAPI, Query, Request +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware + # Local imports from middleware.health_check import HealthCheckMiddleware from v3.api.router import app_v3 + +# Azure monitoring + # Semantic Kernel imports -from v3.orchestration.orchestration_manager import OrchestrationManager # Check if the Application Insights Instrumentation Key is set in the environment variables connection_string = config.APPLICATIONINSIGHTS_CONNECTION_STRING @@ -104,4 +102,11 @@ async def user_browser_language_endpoint(user_language: UserLanguage, request: R if __name__ == "__main__": import uvicorn - uvicorn.run("app_kernel:app", host="127.0.0.1", port=8000, reload=True, log_level="info", access_log=False) + uvicorn.run( + "app_kernel:app", + host="127.0.0.1", + port=8000, + reload=True, + log_level="info", + access_log=False, + ) diff --git a/src/backend/common/config/app_config.py b/src/backend/common/config/app_config.py index 79be7ee30..f359b89e3 100644 --- a/src/backend/common/config/app_config.py +++ b/src/backend/common/config/app_config.py @@ -76,15 +76,23 @@ def __init__(self): # Optional MCP server endpoint (for local MCP server or remote) # Example: http://127.0.0.1:8000/mcp self.MCP_SERVER_ENDPOINT = self._get_optional("MCP_SERVER_ENDPOINT") - self.MCP_SERVER_NAME = self._get_optional("MCP_SERVER_NAME", "MCPGreetingServer") - self.MCP_SERVER_DESCRIPTION = self._get_optional("MCP_SERVER_DESCRIPTION", "MCP server with greeting and planning tools") + self.MCP_SERVER_NAME = self._get_optional( + "MCP_SERVER_NAME", "MCPGreetingServer" + ) + self.MCP_SERVER_DESCRIPTION = self._get_optional( + "MCP_SERVER_DESCRIPTION", "MCP server with greeting and planning tools" + ) self.TENANT_ID = self._get_optional("AZURE_TENANT_ID") self.CLIENT_ID = self._get_optional("AZURE_CLIENT_ID") - self.AZURE_AI_SEARCH_CONNECTION_NAME = self._get_optional("AZURE_AI_SEARCH_CONNECTION_NAME") - self.AZURE_AI_SEARCH_INDEX_NAME = self._get_optional("AZURE_AI_SEARCH_INDEX_NAME") + self.AZURE_AI_SEARCH_CONNECTION_NAME = self._get_optional( + "AZURE_AI_SEARCH_CONNECTION_NAME" + ) + self.AZURE_AI_SEARCH_INDEX_NAME = self._get_optional( + "AZURE_AI_SEARCH_INDEX_NAME" + ) self.AZURE_AI_SEARCH_ENDPOINT = self._get_optional("AZURE_AI_SEARCH_ENDPOINT") self.AZURE_AI_SEARCH_API_KEY = self._get_optional("AZURE_AI_SEARCH_API_KEY") - # self.BING_CONNECTION_NAME = self._get_optional("BING_CONNECTION_NAME") + # self.BING_CONNECTION_NAME = self._get_optional("BING_CONNECTION_NAME") test_team_json = self._get_optional("TEST_TEAM_JSON") @@ -117,7 +125,7 @@ def get_azure_credential(self, client_id=None): ) # CodeQL [SM05139] Okay use of DefaultAzureCredential as it is only used in development else: return ManagedIdentityCredential(client_id=client_id) - + def get_azure_credentials(self): """Retrieve Azure credentials, either from environment variables or managed identity.""" if self._azure_credentials is None: @@ -192,7 +200,8 @@ def get_cosmos_database_client(self): try: if self._cosmos_client is None: self._cosmos_client = CosmosClient( - self.COSMOSDB_ENDPOINT, credential=self.get_azure_credential(self.AZURE_CLIENT_ID) + self.COSMOSDB_ENDPOINT, + credential=self.get_azure_credential(self.AZURE_CLIENT_ID), ) if self._cosmos_database is None: diff --git a/src/backend/common/database/cosmosdb.py b/src/backend/common/database/cosmosdb.py index 90f5a66e4..c6509d0ae 100644 --- a/src/backend/common/database/cosmosdb.py +++ b/src/backend/common/database/cosmosdb.py @@ -1,37 +1,24 @@ """CosmosDB implementation of the database interface.""" -import json -import logging -import uuid - import datetime +import logging from typing import Any, Dict, List, Optional, Type -from azure.cosmos import PartitionKey, exceptions +import v3.models.messages as messages from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._database import DatabaseProxy -from azure.cosmos.exceptions import CosmosResourceExistsError -import v3.models.messages as messages -from common.models.messages_kernel import ( - AgentMessage, - Plan, - Step, - TeamConfiguration, -) -from common.utils.utils_date import DateTimeEncoder - -from .database_base import DatabaseBase from ..models.messages_kernel import ( + AgentMessage, AgentMessageData, BaseDataModel, + DataType, Plan, Step, - AgentMessage, TeamConfiguration, - DataType, UserCurrentTeam, ) +from .database_base import DatabaseBase class CosmosDBClient(DatabaseBase): @@ -189,7 +176,6 @@ async def delete_item(self, item_id: str, partition_key: str) -> None: self.logger.error("Failed to delete item from CosmosDB: %s", str(e)) raise - # Plan Operations async def add_plan(self, plan: Plan) -> None: """Add a plan to CosmosDB.""" @@ -199,7 +185,6 @@ async def update_plan(self, plan: Plan) -> None: """Update a plan in CosmosDB.""" await self.update_item(plan) - async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]: """Retrieve a plan by plan_id.""" query = "SELECT * FROM c WHERE c.id=@plan_id AND c.data_type=@data_type" @@ -234,8 +219,9 @@ async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]: ] return await self.query_items(query, parameters, Plan) - - async def get_all_plans_by_team_id_status(self, user_id: str,team_id: str, status: str) -> List[Plan]: + async def get_all_plans_by_team_id_status( + self, user_id: str, team_id: str, status: str + ) -> List[Plan]: """Retrieve all plans for a specific team.""" query = "SELECT * FROM c WHERE c.team_id=@team_id AND c.data_type=@data_type and c.user_id=@user_id and c.overall_status=@status ORDER BY c._ts DESC" parameters = [ @@ -245,6 +231,7 @@ async def get_all_plans_by_team_id_status(self, user_id: str,team_id: str, statu {"name": "@status", "value": status}, ] return await self.query_items(query, parameters, Plan) + # Step Operations async def add_step(self, step: Step) -> None: """Add a step to CosmosDB.""" @@ -414,8 +401,6 @@ async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]: teams = await self.query_items(query, parameters, UserCurrentTeam) return teams[0] if teams else None - - async def delete_current_team(self, user_id: str) -> bool: """Delete the current team for a user.""" query = "SELECT c.id, c.session_id FROM c WHERE c.user_id=@user_id AND c.data_type=@data_type" @@ -429,9 +414,13 @@ async def delete_current_team(self, user_id: str) -> bool: if items: async for doc in items: try: - await self.container.delete_item(doc["id"], partition_key=doc["session_id"]) + await self.container.delete_item( + doc["id"], partition_key=doc["session_id"] + ) except Exception as e: - self.logger.warning("Failed deleting current team doc %s: %s", doc.get("id"), e) + self.logger.warning( + "Failed deleting current team doc %s: %s", doc.get("id"), e + ) return True @@ -457,9 +446,13 @@ async def delete_plan_by_plan_id(self, plan_id: str) -> bool: if items: async for doc in items: try: - await self.container.delete_item(doc["id"], partition_key=doc["session_id"]) + await self.container.delete_item( + doc["id"], partition_key=doc["session_id"] + ) except Exception as e: - self.logger.warning("Failed deleting current team doc %s: %s", doc.get("id"), e) + self.logger.warning( + "Failed deleting current team doc %s: %s", doc.get("id"), e + ) return True @@ -471,7 +464,6 @@ async def update_mplan(self, mplan: messages.MPlan) -> None: """Update a team configuration in the database.""" await self.update_item(mplan) - async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]: """Retrieve a mplan configuration by mplan_id.""" query = "SELECT * FROM c WHERE c.plan_id=@plan_id AND c.data_type=@data_type" @@ -481,7 +473,6 @@ async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]: ] results = await self.query_items(query, parameters, messages.MPlan) return results[0] if results else None - async def add_agent_message(self, message: AgentMessageData) -> None: """Add an agent message to the database.""" @@ -499,4 +490,4 @@ async def get_agent_messages(self, plan_id: str) -> List[AgentMessageData]: {"name": "@data_type", "value": DataType.m_plan_message}, ] - return await self.query_items(query, parameters, AgentMessageData) \ No newline at end of file + return await self.query_items(query, parameters, AgentMessageData) diff --git a/src/backend/common/database/database_base.py b/src/backend/common/database/database_base.py index 24327ee67..fe67c556c 100644 --- a/src/backend/common/database/database_base.py +++ b/src/backend/common/database/database_base.py @@ -2,7 +2,9 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type + import v3.models.messages as messages + from ..models.messages_kernel import ( AgentMessageData, BaseDataModel, @@ -19,30 +21,25 @@ class DatabaseBase(ABC): @abstractmethod async def initialize(self) -> None: """Initialize the database client and create containers if needed.""" - pass @abstractmethod async def close(self) -> None: """Close database connection.""" - pass # Core CRUD Operations @abstractmethod async def add_item(self, item: BaseDataModel) -> None: """Add an item to the database.""" - pass @abstractmethod async def update_item(self, item: BaseDataModel) -> None: """Update an item in the database.""" - pass @abstractmethod async def get_item_by_id( self, item_id: str, partition_key: str, model_class: Type[BaseDataModel] ) -> Optional[BaseDataModel]: """Retrieve an item by its ID and partition key.""" - pass @abstractmethod async def query_items( @@ -52,116 +49,92 @@ async def query_items( model_class: Type[BaseDataModel], ) -> List[BaseDataModel]: """Query items from the database and return a list of model instances.""" - pass @abstractmethod async def delete_item(self, item_id: str, partition_key: str) -> None: """Delete an item from the database.""" - pass - # Plan Operations @abstractmethod async def add_plan(self, plan: Plan) -> None: """Add a plan to the database.""" - pass @abstractmethod async def update_plan(self, plan: Plan) -> None: """Update a plan in the database.""" - pass @abstractmethod async def get_plan_by_plan_id(self, plan_id: str) -> Optional[Plan]: """Retrieve a plan by plan_id.""" - pass @abstractmethod async def get_plan(self, plan_id: str) -> Optional[Plan]: """Retrieve a plan by plan_id.""" - pass @abstractmethod async def get_all_plans(self) -> List[Plan]: """Retrieve all plans for the user.""" - pass @abstractmethod async def get_all_plans_by_team_id(self, team_id: str) -> List[Plan]: """Retrieve all plans for a specific team.""" - pass @abstractmethod async def get_all_plans_by_team_id_status( - self, user_id:str, team_id: str, status: str + self, user_id: str, team_id: str, status: str ) -> List[Plan]: """Retrieve all plans for a specific team.""" - pass - - # Step Operations @abstractmethod async def add_step(self, step: Step) -> None: """Add a step to the database.""" - pass @abstractmethod async def update_step(self, step: Step) -> None: """Update a step in the database.""" - pass @abstractmethod async def get_steps_by_plan(self, plan_id: str) -> List[Step]: """Retrieve all steps for a plan.""" - pass @abstractmethod async def get_step(self, step_id: str, session_id: str) -> Optional[Step]: """Retrieve a step by step_id and session_id.""" - pass # Team Operations @abstractmethod async def add_team(self, team: TeamConfiguration) -> None: """Add a team configuration to the database.""" - pass @abstractmethod async def update_team(self, team: TeamConfiguration) -> None: """Update a team configuration in the database.""" - pass @abstractmethod async def get_team(self, team_id: str) -> Optional[TeamConfiguration]: """Retrieve a team configuration by team_id.""" - pass @abstractmethod async def get_team_by_id(self, team_id: str) -> Optional[TeamConfiguration]: """Retrieve a team configuration by internal id.""" - pass @abstractmethod async def get_all_teams(self) -> List[TeamConfiguration]: """Retrieve all team configurations for the given user.""" - pass @abstractmethod async def delete_team(self, team_id: str) -> bool: """Delete a team configuration by team_id and return True if deleted.""" - pass # Data Management Operations @abstractmethod async def get_data_by_type(self, data_type: str) -> List[BaseDataModel]: """Retrieve all data of a specific type.""" - pass @abstractmethod async def get_all_items(self) -> List[Dict[str, Any]]: """Retrieve all items as dictionaries.""" - pass # Context Manager Support async def __aenter__(self): @@ -176,17 +149,14 @@ async def __aexit__(self, exc_type, exc, tb): @abstractmethod async def get_steps_for_plan(self, plan_id: str) -> List[Step]: """Convenience method aliasing get_steps_by_plan for compatibility.""" - pass @abstractmethod async def get_current_team(self, user_id: str) -> Optional[UserCurrentTeam]: """Retrieve the current team for a user.""" - pass @abstractmethod async def delete_current_team(self, user_id: str) -> Optional[UserCurrentTeam]: """Retrieve the current team for a user.""" - pass @abstractmethod async def set_current_team(self, current_team: UserCurrentTeam) -> None: @@ -195,28 +165,23 @@ async def set_current_team(self, current_team: UserCurrentTeam) -> None: @abstractmethod async def update_current_team(self, current_team: UserCurrentTeam) -> None: """Update the current team for a user.""" - pass - + @abstractmethod async def delete_plan_by_plan_id(self, plan_id: str) -> bool: """Retrieve the current team for a user.""" - pass @abstractmethod async def add_mplan(self, mplan: messages.MPlan) -> None: """Add a team configuration to the database.""" - pass @abstractmethod async def update_mplan(self, mplan: messages.MPlan) -> None: """Update a team configuration in the database.""" - pass @abstractmethod async def get_mplan(self, plan_id: str) -> Optional[messages.MPlan]: """Retrieve a mplan configuration by plan_id.""" - pass - + @abstractmethod async def add_agent_message(self, message: AgentMessageData) -> None: pass @@ -224,9 +189,7 @@ async def add_agent_message(self, message: AgentMessageData) -> None: @abstractmethod async def update_agent_message(self, message: AgentMessageData) -> None: """Update an agent message in the database.""" - pass @abstractmethod async def get_agent_messages(self, plan_id: str) -> Optional[AgentMessageData]: """Retrieve an agent message by message_id.""" - pass \ No newline at end of file diff --git a/src/backend/common/models/messages_kernel.py b/src/backend/common/models/messages_kernel.py index c57b12e94..dccc5b3b3 100644 --- a/src/backend/common/models/messages_kernel.py +++ b/src/backend/common/models/messages_kernel.py @@ -5,6 +5,7 @@ from semantic_kernel.kernel_pydantic import Field, KernelBaseModel + class DataType(str, Enum): """Enumeration of possible data types for documents in the database.""" @@ -87,14 +88,15 @@ class BaseDataModel(KernelBaseModel): class AgentMessage(BaseDataModel): """Base class for messages sent between agents.""" - data_type: Literal[DataType.agent_message] = Field(DataType.agent_message, Literal=True) + data_type: Literal[DataType.agent_message] = Field( + DataType.agent_message, Literal=True + ) plan_id: str content: str source: str step_id: Optional[str] = None - class Session(BaseDataModel): """Represents a user session.""" @@ -107,7 +109,9 @@ class Session(BaseDataModel): class UserCurrentTeam(BaseDataModel): """Represents the current team of a user.""" - data_type: Literal[DataType.user_current_team] = Field(DataType.user_current_team, Literal=True) + data_type: Literal[DataType.user_current_team] = Field( + DataType.user_current_team, Literal=True + ) user_id: str team_id: str @@ -235,13 +239,11 @@ def update_step_counts(self): self.completed = status_counts[StepStatus.completed] self.failed = status_counts[StepStatus.failed] - if self.total_steps > 0 and (self.completed + self.failed) == self.total_steps: self.overall_status = PlanStatus.completed # Mark the plan as complete if the sum of completed and failed steps equals the total number of steps - # Message classes for communication between agents class InputTask(KernelBaseModel): """Message representing the initial input task from the user.""" @@ -260,15 +262,17 @@ class AgentMessageType(str, Enum): AI_AGENT = "AI_Agent", -class AgentMessageData (BaseDataModel): +class AgentMessageData(BaseDataModel): - data_type: Literal[DataType.m_plan_message] = Field(DataType.m_plan_message, Literal=True) + data_type: Literal[DataType.m_plan_message] = Field( + DataType.m_plan_message, Literal=True + ) plan_id: str user_id: str agent: str m_plan_id: Optional[str] = None - agent_type: AgentMessageType = AgentMessageType.AI_AGENT + agent_type: AgentMessageType = AgentMessageType.AI_AGENT content: str raw_data: str - steps: List[Any] = Field(default_factory=list) - next_steps: List[Any] = Field(default_factory=list) + steps: List[Any] = Field(default_factory=list) + next_steps: List[Any] = Field(default_factory=list) diff --git a/src/backend/common/utils/check_deployments.py b/src/backend/common/utils/check_deployments.py index b2db1e0bf..614c65ea4 100644 --- a/src/backend/common/utils/check_deployments.py +++ b/src/backend/common/utils/check_deployments.py @@ -1,10 +1,10 @@ import asyncio -import sys import os +import sys import traceback # Add the backend directory to the Python path -backend_path = os.path.join(os.path.dirname(__file__), '..', '..') +backend_path = os.path.join(os.path.dirname(__file__), "..", "..") sys.path.insert(0, backend_path) try: @@ -13,38 +13,38 @@ print(f"āŒ Import error: {e}") sys.exit(1) + async def check_deployments(): try: print("šŸ” Checking Azure AI Foundry model deployments...") foundry_service = FoundryService() deployments = await foundry_service.list_model_deployments() - + # Filter successful deployments successful_deployments = [ - d for d in deployments - if d.get('status') == 'Succeeded' - ] - - print(f"āœ… Total deployments: {len(deployments)} (Successful: {len(successful_deployments)})") - - available_models = [ - d.get('name', '').lower() - for d in successful_deployments + d for d in deployments if d.get("status") == "Succeeded" ] - + + print( + f"āœ… Total deployments: {len(deployments)} (Successful: {len(successful_deployments)})" + ) + + available_models = [d.get("name", "").lower() for d in successful_deployments] + # Check what we're looking for - required_models = ['gpt-4o', 'o3', 'gpt-4', 'gpt-35-turbo'] - - print(f"\nšŸ” Checking required models:") + required_models = ["gpt-4o", "o3", "gpt-4", "gpt-35-turbo"] + + print(f"\nšŸ” Checking required models: {required_models}") for model in required_models: if model.lower() in available_models: - print(f'āœ… {model} is available') + print(f"āœ… {model} is available") else: - print(f'āŒ {model} is NOT available') - + print(f"āŒ {model} is NOT available") + except Exception as e: - print(f'āŒ Error: {e}') + print(f"āŒ Error: {e}") traceback.print_exc() + if __name__ == "__main__": asyncio.run(check_deployments()) diff --git a/src/backend/common/utils/event_utils.py b/src/backend/common/utils/event_utils.py index 0e03c0757..97368f622 100644 --- a/src/backend/common/utils/event_utils.py +++ b/src/backend/common/utils/event_utils.py @@ -1,5 +1,5 @@ import logging -import os + from azure.monitor.events.extension import track_event from common.config.app_config import config diff --git a/src/backend/common/utils/utils_kernel.py b/src/backend/common/utils/utils_kernel.py index d27b0d0d9..db4a19ea2 100644 --- a/src/backend/common/utils/utils_kernel.py +++ b/src/backend/common/utils/utils_kernel.py @@ -1,4 +1,4 @@ -""" Utility functions for Semantic Kernel integration and agent management.""" +"""Utility functions for Semantic Kernel integration and agent management.""" import logging from typing import Any, Dict @@ -13,9 +13,10 @@ agent_instances: Dict[str, Dict[str, Any]] = {} azure_agent_instances: Dict[str, Dict[str, AzureAIAgent]] = {} + async def create_RAI_agent() -> FoundryAgentTemplate: """Create and initialize a FoundryAgentTemplate for RAI checks.""" - + agent_name = "RAIAgent" agent_description = "A comprehensive research assistant for integration testing" agent_instructions = ( @@ -40,25 +41,26 @@ async def create_RAI_agent() -> FoundryAgentTemplate: model_deployment_name=model_deployment_name, enable_code_interpreter=False, mcp_config=None, - #bing_config=None, - search_config=None + # bing_config=None, + search_config=None, ) await agent.open() return agent + async def _get_agent_response(agent: FoundryAgentTemplate, query: str) -> str: """Helper method to get complete response from agent.""" response_parts = [] async for message in agent.invoke(query): - if hasattr(message, 'content'): + if hasattr(message, "content"): # Handle different content types properly content = message.content - if hasattr(content, 'text'): + if hasattr(content, "text"): response_parts.append(str(content.text)) elif isinstance(content, list): for item in content: - if hasattr(item, 'text'): + if hasattr(item, "text"): response_parts.append(str(item.text)) else: response_parts.append(str(item)) @@ -66,7 +68,8 @@ async def _get_agent_response(agent: FoundryAgentTemplate, query: str) -> str: response_parts.append(str(content)) else: response_parts.append(str(message)) - return ''.join(response_parts) + return "".join(response_parts) + async def rai_success(description: str) -> bool: """ @@ -83,15 +86,13 @@ async def rai_success(description: str) -> bool: if not rai_agent: print("Failed to create RAI agent") return False - + rai_agent_response = await _get_agent_response(rai_agent, description) # AI returns "TRUE" if content violates rules (should be blocked) # AI returns "FALSE" if content is safe (should be allowed) if str(rai_agent_response).upper() == "TRUE": - logging.warning( - "RAI check failed for content: %s...", description[:50] - ) + logging.warning("RAI check failed for content: %s...", description[:50]) return False # Content should be blocked elif str(rai_agent_response).upper() == "FALSE": logging.info("RAI check passed") @@ -104,7 +105,7 @@ async def rai_success(description: str) -> bool: logging.warning("RAI check returned unexpected status, defaulting to block") return False - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except logging.error("Error in RAI check: %s", str(e)) # Default to blocking the operation if RAI check fails for safety return False @@ -174,6 +175,6 @@ async def rai_validate_team_config(team_config_json: dict) -> tuple[bool, str]: return True, "" - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except logging.error("Error validating team configuration with RAI: %s", str(e)) return False, "Unable to validate team configuration content. Please try again." diff --git a/src/backend/common/utils/websocket_streaming.py b/src/backend/common/utils/websocket_streaming.py index d9e656802..6a1baf519 100644 --- a/src/backend/common/utils/websocket_streaming.py +++ b/src/backend/common/utils/websocket_streaming.py @@ -3,34 +3,36 @@ This is a basic implementation that can be expanded based on your backend framework """ -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from typing import Dict, Set -import json import asyncio +import json import logging +from typing import Dict, Set + +from fastapi import WebSocket, WebSocketDisconnect logger = logging.getLogger(__name__) + class WebSocketManager: def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.plan_subscriptions: Dict[str, Set[str]] = {} # plan_id -> set of connection_ids - + async def connect(self, websocket: WebSocket, connection_id: str): await websocket.accept() self.active_connections[connection_id] = websocket logger.info(f"WebSocket connection established: {connection_id}") - + def disconnect(self, connection_id: str): if connection_id in self.active_connections: del self.active_connections[connection_id] - + # Remove from all plan subscriptions for plan_id, subscribers in self.plan_subscriptions.items(): subscribers.discard(connection_id) - + logger.info(f"WebSocket connection closed: {connection_id}") - + async def send_personal_message(self, message: dict, connection_id: str): if connection_id in self.active_connections: websocket = self.active_connections[connection_id] @@ -39,14 +41,14 @@ async def send_personal_message(self, message: dict, connection_id: str): except Exception as e: logger.error(f"Error sending message to {connection_id}: {e}") self.disconnect(connection_id) - + async def broadcast_to_plan(self, message: dict, plan_id: str): """Broadcast message to all subscribers of a specific plan""" if plan_id not in self.plan_subscriptions: return - + disconnected_connections = [] - + for connection_id in self.plan_subscriptions[plan_id].copy(): if connection_id in self.active_connections: websocket = self.active_connections[connection_id] @@ -55,67 +57,75 @@ async def broadcast_to_plan(self, message: dict, plan_id: str): except Exception as e: logger.error(f"Error broadcasting to {connection_id}: {e}") disconnected_connections.append(connection_id) - + # Clean up failed connections for connection_id in disconnected_connections: self.disconnect(connection_id) - + def subscribe_to_plan(self, connection_id: str, plan_id: str): if plan_id not in self.plan_subscriptions: self.plan_subscriptions[plan_id] = set() - + self.plan_subscriptions[plan_id].add(connection_id) logger.info(f"Connection {connection_id} subscribed to plan {plan_id}") - + def unsubscribe_from_plan(self, connection_id: str, plan_id: str): if plan_id in self.plan_subscriptions: self.plan_subscriptions[plan_id].discard(connection_id) logger.info(f"Connection {connection_id} unsubscribed from plan {plan_id}") + # Global WebSocket manager instance ws_manager = WebSocketManager() + # WebSocket endpoint async def websocket_streaming_endpoint(websocket: WebSocket): connection_id = f"conn_{id(websocket)}" await ws_manager.connect(websocket, connection_id) - + try: while True: data = await websocket.receive_text() message = json.loads(data) - + message_type = message.get("type") - + if message_type == "subscribe_plan": plan_id = message.get("plan_id") if plan_id: ws_manager.subscribe_to_plan(connection_id, plan_id) - + # Send confirmation - await ws_manager.send_personal_message({ - "type": "subscription_confirmed", - "plan_id": plan_id - }, connection_id) - + await ws_manager.send_personal_message( + {"type": "subscription_confirmed", "plan_id": plan_id}, + connection_id, + ) + elif message_type == "unsubscribe_plan": plan_id = message.get("plan_id") if plan_id: ws_manager.unsubscribe_from_plan(connection_id, plan_id) - + else: logger.warning(f"Unknown message type: {message_type}") - + except WebSocketDisconnect: ws_manager.disconnect(connection_id) except Exception as e: logger.error(f"WebSocket error: {e}") ws_manager.disconnect(connection_id) + # Example function to send plan updates (call this from your plan execution logic) -async def send_plan_update(plan_id: str, step_id: str = None, agent_name: str = None, - content: str = None, status: str = "in_progress", - message_type: str = "action"): +async def send_plan_update( + plan_id: str, + step_id: str = None, + agent_name: str = None, + content: str = None, + status: str = "in_progress", + message_type: str = "action", +): """ Send a streaming update for a specific plan """ @@ -128,15 +138,17 @@ async def send_plan_update(plan_id: str, step_id: str = None, agent_name: str = "content": content, "status": status, "message_type": message_type, - "timestamp": asyncio.get_event_loop().time() - } + "timestamp": asyncio.get_event_loop().time(), + }, } - + await ws_manager.broadcast_to_plan(message, plan_id) + # Example function to send agent messages -async def send_agent_message(plan_id: str, agent_name: str, content: str, - message_type: str = "thinking"): +async def send_agent_message( + plan_id: str, agent_name: str, content: str, message_type: str = "thinking" +): """ Send a streaming message from an agent """ @@ -147,14 +159,17 @@ async def send_agent_message(plan_id: str, agent_name: str, content: str, "agent_name": agent_name, "content": content, "message_type": message_type, - "timestamp": asyncio.get_event_loop().time() - } + "timestamp": asyncio.get_event_loop().time(), + }, } - + await ws_manager.broadcast_to_plan(message, plan_id) + # Example function to send step updates -async def send_step_update(plan_id: str, step_id: str, status: str, content: str = None): +async def send_step_update( + plan_id: str, step_id: str, status: str, content: str = None +): """ Send a streaming update for a specific step """ @@ -165,12 +180,13 @@ async def send_step_update(plan_id: str, step_id: str, status: str, content: str "step_id": step_id, "status": status, "content": content, - "timestamp": asyncio.get_event_loop().time() - } + "timestamp": asyncio.get_event_loop().time(), + }, } - + await ws_manager.broadcast_to_plan(message, plan_id) + # Example integration with FastAPI """ from fastapi import FastAPI @@ -185,20 +201,14 @@ async def websocket_endpoint(websocket: WebSocket): async def execute_plan_step(plan_id: str, step_id: str): # Send initial update await send_step_update(plan_id, step_id, "in_progress", "Starting step execution...") - # Simulate some work await asyncio.sleep(2) - # Send agent thinking message await send_agent_message(plan_id, "Data Analyst", "Analyzing the requirements...", "thinking") - await asyncio.sleep(1) - # Send agent action message await send_agent_message(plan_id, "Data Analyst", "Processing data and generating insights...", "action") - await asyncio.sleep(3) - # Send completion update await send_step_update(plan_id, step_id, "completed", "Step completed successfully!") """ diff --git a/src/backend/tests/middleware/test_health_check.py b/src/backend/tests/middleware/test_health_check.py index 52a5a985e..727692c39 100644 --- a/src/backend/tests/middleware/test_health_check.py +++ b/src/backend/tests/middleware/test_health_check.py @@ -1,10 +1,9 @@ -from src.backend.middleware.health_check import ( - HealthCheckMiddleware, - HealthCheckResult, -) +from asyncio import sleep + from fastapi import FastAPI from starlette.testclient import TestClient -from asyncio import sleep + +from src.backend.middleware.health_check import HealthCheckMiddleware, HealthCheckResult # Updated helper functions for test health checks diff --git a/src/backend/tests/test_config.py b/src/backend/tests/test_config.py index cc2d74f83..2ab1b51d1 100644 --- a/src/backend/tests/test_config.py +++ b/src/backend/tests/test_config.py @@ -14,23 +14,19 @@ "COSMOSDB_ENDPOINT": "https://mock-cosmosdb.documents.azure.com:443/", "COSMOSDB_DATABASE": "mock_database", "COSMOSDB_CONTAINER": "mock_container", - # Azure OpenAI "AZURE_OPENAI_DEPLOYMENT_NAME": "mock-deployment", "AZURE_OPENAI_API_VERSION": "2024-11-20", "AZURE_OPENAI_ENDPOINT": "https://mock-openai-endpoint.azure.com/", - # Optional auth (kept for completeness) "AZURE_TENANT_ID": "mock-tenant-id", "AZURE_CLIENT_ID": "mock-client-id", "AZURE_CLIENT_SECRET": "mock-client-secret", - # Azure AI Project (required by current AppConfig) "AZURE_AI_SUBSCRIPTION_ID": "00000000-0000-0000-0000-000000000000", "AZURE_AI_RESOURCE_GROUP": "rg-test", "AZURE_AI_PROJECT_NAME": "proj-test", "AZURE_AI_AGENT_ENDPOINT": "https://agents.example.com/", - # Misc "USER_LOCAL_BROWSER_LANGUAGE": "en-US", } @@ -44,23 +40,31 @@ def GetRequiredConfig(name: str, default=None): return app_config._get_required(name, default) + def GetOptionalConfig(name: str, default: str = ""): return app_config._get_optional(name, default) + def GetBoolConfig(name: str) -> bool: return app_config._get_bool(name) # ---- Tests (unchanged semantics) ---- + @patch.dict(os.environ, MOCK_ENV_VARS, clear=False) def test_get_required_config(): assert GetRequiredConfig("COSMOSDB_ENDPOINT") == MOCK_ENV_VARS["COSMOSDB_ENDPOINT"] + @patch.dict(os.environ, MOCK_ENV_VARS, clear=False) def test_get_optional_config(): assert GetOptionalConfig("NON_EXISTENT_VAR", "default_value") == "default_value" - assert GetOptionalConfig("COSMOSDB_DATABASE", "default_db") == MOCK_ENV_VARS["COSMOSDB_DATABASE"] + assert ( + GetOptionalConfig("COSMOSDB_DATABASE", "default_db") + == MOCK_ENV_VARS["COSMOSDB_DATABASE"] + ) + @patch.dict(os.environ, MOCK_ENV_VARS, clear=False) def test_get_bool_config(): diff --git a/src/backend/tests/test_otlp_tracing.py b/src/backend/tests/test_otlp_tracing.py index 576e2e0c2..3fd01ad90 100644 --- a/src/backend/tests/test_otlp_tracing.py +++ b/src/backend/tests/test_otlp_tracing.py @@ -1,6 +1,4 @@ -import sys import os -from unittest.mock import patch, MagicMock from src.backend.common.utils.otlp_tracing import ( configure_oltp_tracing, ) # Import directly since it's in backend diff --git a/src/backend/tests/test_team_specific_methods.py b/src/backend/tests/test_team_specific_methods.py index 7f43b3780..0e81558be 100644 --- a/src/backend/tests/test_team_specific_methods.py +++ b/src/backend/tests/test_team_specific_methods.py @@ -4,21 +4,17 @@ """ import asyncio -import uuid -from datetime import datetime, timezone +import os # Add the parent directory to the path so we can import our modules import sys -import os +import uuid +from datetime import datetime, timezone sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from common.models.messages_kernel import ( - TeamConfiguration, - TeamAgent, - StartingTask, -) +from common.models.messages_kernel import StartingTask, TeamAgent, TeamConfiguration async def test_team_specific_methods(): diff --git a/src/backend/v3/api/router.py b/src/backend/v3/api/router.py index f3e6fa436..4ed51a4a2 100644 --- a/src/backend/v3/api/router.py +++ b/src/backend/v3/api/router.py @@ -7,20 +7,34 @@ import v3.models.messages as messages from auth.auth_utils import get_authenticated_user_details -from common.config.app_config import config from common.database.database_factory import DatabaseFactory -from common.models.messages_kernel import (InputTask, Plan, PlanStatus, - PlanWithSteps, TeamSelectionRequest) +from common.models.messages_kernel import ( + InputTask, + Plan, + PlanStatus, + TeamSelectionRequest, +) from common.utils.event_utils import track_event_if_configured -from common.utils.utils_date import format_dates_in_messages from common.utils.utils_kernel import rai_success, rai_validate_team_config -from fastapi import (APIRouter, BackgroundTasks, File, HTTPException, Query, - Request, UploadFile, WebSocket, WebSocketDisconnect) -from semantic_kernel.agents.runtime import InProcessRuntime +from fastapi import ( + APIRouter, + BackgroundTasks, + File, + HTTPException, + Query, + Request, + UploadFile, + WebSocket, + WebSocketDisconnect, +) from v3.common.services.plan_service import PlanService from v3.common.services.team_service import TeamService -from v3.config.settings import (connection_config, current_user_id, - orchestration_config, team_config) +from v3.config.settings import ( + connection_config, + current_user_id, + orchestration_config, + team_config, +) from v3.orchestration.orchestration_manager import OrchestrationManager router = APIRouter() @@ -33,7 +47,9 @@ @app_v3.websocket("/socket/{process_id}") -async def start_comms(websocket: WebSocket, process_id: str, user_id: str = Query(None)): +async def start_comms( + websocket: WebSocket, process_id: str, user_id: str = Query(None) +): """Web-Socket endpoint for real-time process status updates.""" # Always accept the WebSocket connection first @@ -489,8 +505,8 @@ async def user_clarification( ) # Set the approval in the orchestration config if user_id and human_feedback.request_id: - ### validate rai - if human_feedback.answer != None or human_feedback.answer != "": + # validate rai + if human_feedback.answer is not None or human_feedback.answer != "": if not await rai_success(human_feedback.answer): track_event_if_configured( "RAI failed", @@ -790,7 +806,6 @@ async def upload_team_config( { "status": "success", "team_id": team_id, - "team_id": team_config.team_id, "user_id": user_id, "agents_count": len(team_config.agents), "tasks_count": len(team_config.starting_tasks), @@ -1204,9 +1219,9 @@ async def get_plans(request: Request): ) raise HTTPException(status_code=400, detail="no user") - #### Replace the following with code to get plan run history from the database + # Replace the following with code to get plan run history from the database - # # Initialize memory context + # Initialize memory context memory_store = await DatabaseFactory.get_database(user_id=user_id) current_team = await memory_store.get_current_team(user_id=user_id) @@ -1222,7 +1237,10 @@ async def get_plans(request: Request): # Get plans is called in the initial side rendering of the frontend @app_v3.get("/plan") -async def get_plan_by_id(request: Request, plan_id: Optional[str] = Query(None),): +async def get_plan_by_id( + request: Request, + plan_id: Optional[str] = Query(None), +): """ Retrieve plans for the current user. @@ -1289,9 +1307,9 @@ async def get_plan_by_id(request: Request, plan_id: Optional[str] = Query(None) ) raise HTTPException(status_code=400, detail="no user") - #### Replace the following with code to get plan run history from the database + # Replace the following with code to get plan run history from the database - # # Initialize memory context + # Initialize memory context memory_store = await DatabaseFactory.get_database(user_id=user_id) try: if plan_id: diff --git a/src/backend/v3/callbacks/global_debug.py b/src/backend/v3/callbacks/global_debug.py index a44bde4fe..3da87681f 100644 --- a/src/backend/v3/callbacks/global_debug.py +++ b/src/backend/v3/callbacks/global_debug.py @@ -1,7 +1,6 @@ - class DebugGlobalAccess: """Class to manage global access to the Magentic orchestration manager.""" - + _managers = [] @classmethod @@ -12,4 +11,4 @@ def add_manager(cls, manager): @classmethod def get_managers(cls): """Get the list of all managers.""" - return cls._managers \ No newline at end of file + return cls._managers diff --git a/src/backend/v3/callbacks/response_handlers.py b/src/backend/v3/callbacks/response_handlers.py index ef4c24f72..fb91dd9c0 100644 --- a/src/backend/v3/callbacks/response_handlers.py +++ b/src/backend/v3/callbacks/response_handlers.py @@ -2,17 +2,20 @@ Enhanced response callbacks for employee onboarding agent system. Provides detailed monitoring and response handling for different agent types. """ + import asyncio -import json import logging -import sys import time -from semantic_kernel.contents import (ChatMessageContent, - StreamingChatMessageContent) -from v3.config.settings import connection_config, current_user_id -from v3.models.messages import (AgentMessage, AgentMessageStreaming, - AgentToolCall, AgentToolMessage, WebsocketMessageType) +from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent +from v3.config.settings import connection_config +from v3.models.messages import ( + AgentMessage, + AgentMessageStreaming, + AgentToolCall, + AgentToolMessage, + WebsocketMessageType, +) def agent_response_callback(message: ChatMessageContent, user_id: str = None) -> None: @@ -21,41 +24,71 @@ def agent_response_callback(message: ChatMessageContent, user_id: str = None) -> # Get agent name to determine handling agent_name = message.name or "Unknown Agent" - # Get message type - content_type = getattr(message, 'content_type', 'text') - role = getattr(message, 'role', 'unknown') + role = getattr(message, "role", "unknown") # Send to WebSocket if user_id: try: - if message.items and message.items[0].content_type == 'function_call': - final_message = AgentToolMessage(agent_name=agent_name) + if message.items and message.items[0].content_type == "function_call": + final_message = AgentToolMessage(agent_name=agent_name) for item in message.items: - if item.content_type == 'function_call': - tool_call = AgentToolCall(tool_name=item.name or "unknown_tool", arguments=item.arguments or {}) + if item.content_type == "function_call": + tool_call = AgentToolCall( + tool_name=item.name or "unknown_tool", + arguments=item.arguments or {}, + ) final_message.tool_calls.append(tool_call) - asyncio.create_task(connection_config.send_status_update_async(final_message, user_id, message_type=WebsocketMessageType.AGENT_TOOL_MESSAGE)) + asyncio.create_task( + connection_config.send_status_update_async( + final_message, + user_id, + message_type=WebsocketMessageType.AGENT_TOOL_MESSAGE, + ) + ) logging.info(f"Function call: {final_message}") - elif message.items and message.items[0].content_type == 'function_result': + elif message.items and message.items[0].content_type == "function_result": # skip returning these results for now - agent will return in a later message pass else: - final_message = AgentMessage(agent_name=agent_name, timestamp=time.time() or "", content=message.content or "") - - asyncio.create_task(connection_config.send_status_update_async(final_message, user_id, message_type=WebsocketMessageType.AGENT_MESSAGE)) + final_message = AgentMessage( + agent_name=agent_name, + timestamp=time.time() or "", + content=message.content or "", + ) + + asyncio.create_task( + connection_config.send_status_update_async( + final_message, + user_id, + message_type=WebsocketMessageType.AGENT_MESSAGE, + ) + ) logging.info(f"{role.capitalize()} message: {final_message}") except Exception as e: logging.error(f"Response_callback: Error sending WebSocket message: {e}") - -async def streaming_agent_response_callback(streaming_message: StreamingChatMessageContent, is_final: bool, user_id: str = None) -> None: + + +async def streaming_agent_response_callback( + streaming_message: StreamingChatMessageContent, is_final: bool, user_id: str = None +) -> None: """Simple streaming callback to show real-time agent responses.""" # process only content messages - if hasattr(streaming_message, 'content') and streaming_message.content: + if hasattr(streaming_message, "content") and streaming_message.content: if user_id: try: - message = AgentMessageStreaming(agent_name=streaming_message.name or "Unknown Agent", content=streaming_message.content, is_final=is_final) - await connection_config.send_status_update_async(message, user_id, message_type=WebsocketMessageType.AGENT_MESSAGE_STREAMING) + message = AgentMessageStreaming( + agent_name=streaming_message.name or "Unknown Agent", + content=streaming_message.content, + is_final=is_final, + ) + await connection_config.send_status_update_async( + message, + user_id, + message_type=WebsocketMessageType.AGENT_MESSAGE_STREAMING, + ) except Exception as e: - logging.error(f"Response_callback: Error sending streaming WebSocket message: {e}") \ No newline at end of file + logging.error( + f"Response_callback: Error sending streaming WebSocket message: {e}" + ) diff --git a/src/backend/v3/common/services/__init__.py b/src/backend/v3/common/services/__init__.py index ef16a965d..4c07712c9 100644 --- a/src/backend/v3/common/services/__init__.py +++ b/src/backend/v3/common/services/__init__.py @@ -6,10 +6,10 @@ - FoundryService: helper around Azure AI Foundry (AIProjectClient) """ +from .agents_service import AgentsService from .base_api_service import BaseAPIService -from .mcp_service import MCPService from .foundry_service import FoundryService -from .agents_service import AgentsService +from .mcp_service import MCPService __all__ = [ "BaseAPIService", diff --git a/src/backend/v3/common/services/agents_service.py b/src/backend/v3/common/services/agents_service.py index d4f233716..fc4e7fa06 100644 --- a/src/backend/v3/common/services/agents_service.py +++ b/src/backend/v3/common/services/agents_service.py @@ -9,11 +9,11 @@ agent instances. """ -from typing import Any, Dict, List, Union import logging +from typing import Any, Dict, List, Union +from common.models.messages_kernel import TeamAgent, TeamConfiguration from v3.common.services.team_service import TeamService -from common.models.messages_kernel import TeamConfiguration, TeamAgent class AgentsService: diff --git a/src/backend/v3/common/services/base_api_service.py b/src/backend/v3/common/services/base_api_service.py index 2c43fe6a0..8f8b48ef1 100644 --- a/src/backend/v3/common/services/base_api_service.py +++ b/src/backend/v3/common/services/base_api_service.py @@ -1,8 +1,6 @@ -import asyncio from typing import Any, Dict, Optional, Union import aiohttp - from common.config.app_config import config diff --git a/src/backend/v3/common/services/foundry_service.py b/src/backend/v3/common/services/foundry_service.py index 8f41020f8..563f5c56c 100644 --- a/src/backend/v3/common/services/foundry_service.py +++ b/src/backend/v3/common/services/foundry_service.py @@ -57,7 +57,6 @@ async def list_model_deployments(self) -> List[Dict[str, Any]]: credential = config.get_azure_credentials() token = credential.get_token(config.AZURE_MANAGEMENT_SCOPE) - # Extract Azure OpenAI resource name from endpoint URL openai_endpoint = config.AZURE_OPENAI_ENDPOINT # Extract resource name from URL like "https://aisa-macae-d3x6aoi7uldi.openai.azure.com/" diff --git a/src/backend/v3/common/services/plan_service.py b/src/backend/v3/common/services/plan_service.py index 1c0da50b0..ff7d5b30e 100644 --- a/src/backend/v3/common/services/plan_service.py +++ b/src/backend/v3/common/services/plan_service.py @@ -1,23 +1,17 @@ -from dataclasses import Field, asdict import json import logging -import time -from typing import Dict, Any, Optional -from common.database.database_factory import DatabaseFactory +from dataclasses import asdict -from v3.models.models import MPlan import v3.models.messages as messages +from common.database.database_factory import DatabaseFactory from common.models.messages_kernel import ( AgentMessageData, AgentMessageType, AgentType, PlanStatus, ) -from v3.config.settings import orchestration_config from common.utils.event_utils import track_event_if_configured -import uuid -from semantic_kernel.kernel_pydantic import Field - +from v3.config.settings import orchestration_config logger = logging.getLogger(__name__) diff --git a/src/backend/v3/common/services/team_service.py b/src/backend/v3/common/services/team_service.py index 1e7251921..02b9cdc2a 100644 --- a/src/backend/v3/common/services/team_service.py +++ b/src/backend/v3/common/services/team_service.py @@ -1,6 +1,4 @@ -import json import logging -import os import uuid from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Tuple @@ -10,7 +8,6 @@ HttpResponseError, ResourceNotFoundError, ) - from azure.search.documents.indexes import SearchIndexClient from common.config.app_config import config from common.database.database_base import DatabaseBase @@ -243,7 +240,9 @@ async def delete_user_current_team(self, user_id: str) -> bool: self.logger.error("Error deleting current team: %s", str(e)) return False - async def handle_team_selection(self, user_id: str, team_id: str) -> UserCurrentTeam: + async def handle_team_selection( + self, user_id: str, team_id: str + ) -> UserCurrentTeam: """ Set a default team for a user. @@ -317,7 +316,7 @@ def extract_models_from_agent(self, agent: Dict[str, Any]) -> set: """ models = set() - # Skip proxy agents - they don't need deployment models + # Skip proxy agents - they don't need deployment models if agent.get("name", "").lower() == "proxyagent": return models diff --git a/src/backend/v3/config/settings.py b/src/backend/v3/config/settings.py index 69f4a55c4..fbf12e276 100644 --- a/src/backend/v3/config/settings.py +++ b/src/backend/v3/config/settings.py @@ -14,8 +14,10 @@ from fastapi import WebSocket from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration from semantic_kernel.connectors.ai.open_ai import ( - AzureChatCompletion, OpenAIChatPromptExecutionSettings) -from v3.models.messages import WebsocketMessageType, MPlan + AzureChatCompletion, + OpenAIChatPromptExecutionSettings, +) +from v3.models.messages import MPlan, WebsocketMessageType logger = logging.getLogger(__name__) @@ -32,7 +34,7 @@ def __init__(self): self.endpoint = config.AZURE_OPENAI_ENDPOINT self.reasoning_model = config.REASONING_MODEL_NAME self.standard_model = config.AZURE_OPENAI_DEPLOYMENT_NAME - #self.bing_connection_name = config.AZURE_BING_CONNECTION_NAME + # self.bing_connection_name = config.AZURE_BING_CONNECTION_NAME # Create credential self.credential = config.get_azure_credentials() @@ -86,7 +88,9 @@ def __init__(self): self.approvals: Dict[str, bool] = {} # m_plan_id -> approval status self.sockets: Dict[str, WebSocket] = {} # user_id -> WebSocket self.clarifications: Dict[str, str] = {} # m_plan_id -> clarification response - self.max_rounds: int = 20 # Maximum number of replanning rounds 20 needed to accommodate complex tasks + self.max_rounds: int = ( + 20 # Maximum number of replanning rounds 20 needed to accommodate complex tasks + ) def get_current_orchestration(self, user_id: str) -> MagenticOrchestration: """get existing orchestration instance.""" @@ -197,7 +201,6 @@ async def send_status_update_async( ) return - # Convert message to proper format for frontend try: if hasattr(message, "to_dict"): @@ -215,12 +218,8 @@ async def send_status_update_async( except Exception as e: logger.error("Error processing message data: %s", e) message_data = str(message) - - - standard_message = { - "type": message_type, - "data": message_data - } + + standard_message = {"type": message_type, "data": message_data} connection = self.get_connection(process_id) if connection: try: diff --git a/src/backend/v3/magentic_agents/common/lifecycle.py b/src/backend/v3/magentic_agents/common/lifecycle.py index e9f3f7cbe..860dd5666 100644 --- a/src/backend/v3/magentic_agents/common/lifecycle.py +++ b/src/backend/v3/magentic_agents/common/lifecycle.py @@ -1,6 +1,4 @@ -import os from contextlib import AsyncExitStack -from dataclasses import dataclass from typing import Any from azure.ai.projects.aio import AIProjectClient @@ -18,7 +16,7 @@ class MCPEnabledBase: def __init__(self, mcp: MCPConfig | None = None) -> None: self._stack: AsyncExitStack | None = None - self.mcp_cfg: MCPConfig | None = mcp + self.mcp_cfg: MCPConfig | None = mcp self.mcp_plugin: MCPStreamableHttpPlugin | None = None self._agent: Any | None = None # delegate target @@ -34,7 +32,7 @@ async def close(self) -> None: if self._stack is None: return try: - #self.cred.close() + # self.cred.close() await self._stack.aclose() finally: self._stack = None @@ -76,12 +74,12 @@ async def _after_open(self) -> None: async def _enter_mcp_if_configured(self) -> None: if not self.mcp_cfg: return - #headers = self._build_mcp_headers() + # headers = self._build_mcp_headers() plugin = MCPStreamableHttpPlugin( name=self.mcp_cfg.name, description=self.mcp_cfg.description, url=self.mcp_cfg.url, - #headers=headers, + # headers=headers, ) # Enter MCP async context via the stack to ensure correct LIFO cleanup if self._stack is None: @@ -121,4 +119,4 @@ async def open(self) -> "AzureAgentBase": async def close(self) -> None: await self.creds.close() - await super().close() \ No newline at end of file + await super().close() diff --git a/src/backend/v3/magentic_agents/foundry_agent.py b/src/backend/v3/magentic_agents/foundry_agent.py index 2cc62f639..d1bab2336 100644 --- a/src/backend/v3/magentic_agents/foundry_agent.py +++ b/src/backend/v3/magentic_agents/foundry_agent.py @@ -1,11 +1,9 @@ -""" Agent template for building foundry agents with Azure AI Search, Bing, and MCP plugins. """ +"""Agent template for building foundry agents with Azure AI Search, Bing, and MCP plugins.""" -import asyncio import logging from typing import Awaitable, List, Optional -from azure.ai.agents.models import (AzureAISearchTool, BingGroundingTool, - CodeInterpreterToolDefinition) +from azure.ai.agents.models import AzureAISearchTool, CodeInterpreterToolDefinition from semantic_kernel.agents import Agent, AzureAIAgent # pylint: disable=E0611 from v3.magentic_agents.common.lifecycle import AzureAgentBase from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig @@ -14,26 +12,30 @@ # SearchConfig) # exception too broad warning -# pylint: disable=w0718 +# pylint: disable=w0718 + class FoundryAgentTemplate(AzureAgentBase): """Agent that uses Azure AI Search and Bing tools for information retrieval.""" - def __init__(self, agent_name: str, - agent_description: str, - agent_instructions: str, - model_deployment_name: str, - enable_code_interpreter: bool = False, - mcp_config: MCPConfig | None = None, - #bing_config: BingConfig | None = None, - search_config: SearchConfig | None = None) -> None: + def __init__( + self, + agent_name: str, + agent_description: str, + agent_instructions: str, + model_deployment_name: str, + enable_code_interpreter: bool = False, + mcp_config: MCPConfig | None = None, + # bing_config: BingConfig | None = None, + search_config: SearchConfig | None = None, + ) -> None: super().__init__(mcp=mcp_config) self.agent_name = agent_name self.agent_description = agent_description self.agent_instructions = agent_instructions self.model_deployment_name = model_deployment_name self.enable_code_interpreter = enable_code_interpreter - #self.bing = bing_config + # self.bing = bing_config self.mcp = mcp_config self.search = search_config self._search_connection = None @@ -41,7 +43,9 @@ def __init__(self, agent_name: str, self.logger = logging.getLogger(__name__) # input validation if self.model_deployment_name in ["o3", "o4-mini"]: - raise ValueError("The current version of Foundry agents do not support reasoning models.") + raise ValueError( + "The current version of Foundry agents do not support reasoning models." + ) # Uncomment to enable bing grounding capabilities (requires Bing connection in Foundry and uncommenting other code) # async def _make_bing_tool(self) -> Optional[BingGroundingTool]: @@ -66,22 +70,30 @@ async def _make_azure_search_tool(self) -> Optional[AzureAISearchTool]: try: # Get the existing connection by name - self._search_connection = await self.client.connections.get(name=self.search.connection_name) - self.logger.info("Found Azure AI Search connection: %s", self._search_connection.id) + self._search_connection = await self.client.connections.get( + name=self.search.connection_name + ) + self.logger.info( + "Found Azure AI Search connection: %s", self._search_connection.id + ) # Create the Azure AI Search tool search_tool = AzureAISearchTool( index_connection_id=self._search_connection.id, # Try connection_id first - index_name=self.search.index_name + index_name=self.search.index_name, + ) + self.logger.info( + "Azure AI Search tool created for index: %s", self.search.index_name ) - self.logger.info("Azure AI Search tool created for index: %s", self.search.index_name) return search_tool except Exception as ex: self.logger.error( "Azure AI Search tool creation failed: %s | Connection name: %s | Index name: %s | " "Make sure the connection exists in Azure AI Foundry portal", - ex, self.search.connection_name, self.search.index_name + ex, + self.search.connection_name, + self.search.index_name, ) return None @@ -96,9 +108,14 @@ async def _collect_tools_and_resources(self) -> tuple[List, dict]: if search_tool: tools.extend(search_tool.definitions) tool_resources = search_tool.resources - self.logger.info("Added Azure AI Search tools: %d tools", len(search_tool.definitions)) + self.logger.info( + "Added Azure AI Search tools: %d tools", + len(search_tool.definitions), + ) else: - self.logger.error("Something went wrong, Azure AI Search tool not configured") + self.logger.error( + "Something went wrong, Azure AI Search tool not configured" + ) # Add Bing search tool # if self.bing and self.bing.connection_name: @@ -114,7 +131,9 @@ async def _collect_tools_and_resources(self) -> tuple[List, dict]: tools.append(CodeInterpreterToolDefinition()) self.logger.info("Added Code Interpreter tool") except ImportError as ie: - self.logger.error("Code Interpreter tool requires additional dependencies: %s", ie) + self.logger.error( + "Code Interpreter tool requires additional dependencies: %s", ie + ) self.logger.info("Total tools configured: %d", len(tools)) return tools, tool_resources @@ -136,7 +155,7 @@ async def _after_open(self) -> None: description=self.agent_description, instructions=self.agent_instructions, tools=tools, - tool_resources=tool_resources + tool_resources=tool_resources, ) # Add MCP plugins if available @@ -178,15 +197,17 @@ async def fetch_run_details(self, thread_id: str, run_id: str): run = await self.client.agents.runs.get(thread=thread_id, run=run_id) self.logger.error( "Run failure details | status=%s | id=%s | last_error=%s | usage=%s", - getattr(run, 'status', None), + getattr(run, "status", None), run_id, - getattr(run, 'last_error', None), - getattr(run, 'usage', None), + getattr(run, "last_error", None), + getattr(run, "usage", None), ) except Exception as ex: self.logger.error("Could not fetch run details: %s", ex) - async def _get_azure_ai_agent_definition(self, agent_name: str)-> Awaitable[Agent | None]: + async def _get_azure_ai_agent_definition( + self, agent_name: str + ) -> Awaitable[Agent | None]: """ Gets an Azure AI Agent with the specified name and instructions using AIProjectClient if it is already created. """ @@ -222,23 +243,25 @@ async def _get_azure_ai_agent_definition(self, agent_name: str)-> Awaitable[Agen ) -async def create_foundry_agent(agent_name:str, - agent_description:str, - agent_instructions:str, - model_deployment_name:str, - mcp_config:MCPConfig, - #bing_config:BingConfig, - search_config:SearchConfig) -> FoundryAgentTemplate: - +async def create_foundry_agent( + agent_name: str, + agent_description: str, + agent_instructions: str, + model_deployment_name: str, + mcp_config: MCPConfig, + # bing_config:BingConfig, + search_config: SearchConfig, +) -> FoundryAgentTemplate: """Factory function to create and open a ResearcherAgent.""" - agent = FoundryAgentTemplate(agent_name=agent_name, - agent_description=agent_description, - agent_instructions=agent_instructions, - model_deployment_name=model_deployment_name, - enable_code_interpreter=True, - mcp_config=mcp_config, - #bing_config=bing_config, - search_config=search_config) + agent = FoundryAgentTemplate( + agent_name=agent_name, + agent_description=agent_description, + agent_instructions=agent_instructions, + model_deployment_name=model_deployment_name, + enable_code_interpreter=True, + mcp_config=mcp_config, + # bing_config=bing_config, + search_config=search_config, + ) await agent.open() return agent - diff --git a/src/backend/v3/magentic_agents/magentic_agent_factory.py b/src/backend/v3/magentic_agents/magentic_agent_factory.py index c11e18a2f..dd13e27a4 100644 --- a/src/backend/v3/magentic_agents/magentic_agent_factory.py +++ b/src/backend/v3/magentic_agents/magentic_agent_factory.py @@ -1,10 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. -""" Factory for creating and managing magentic agents from JSON configurations.""" +"""Factory for creating and managing magentic agents from JSON configurations.""" import json import logging -import os -from pathlib import Path from types import SimpleNamespace from typing import List, Union @@ -13,6 +11,7 @@ from v3.config.settings import current_user_id from v3.magentic_agents.foundry_agent import FoundryAgentTemplate from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig + # from v3.magentic_agents.models.agent_models import (BingConfig, MCPConfig, # SearchConfig) from v3.magentic_agents.proxy_agent import ProxyAgent @@ -21,79 +20,87 @@ class UnsupportedModelError(Exception): """Raised when an unsupported model is specified.""" - pass class InvalidConfigurationError(Exception): """Raised when agent configuration is invalid.""" - pass class MagenticAgentFactory: """Factory for creating and managing magentic agents from JSON configurations.""" - + def __init__(self): self.logger = logging.getLogger(__name__) self._agent_list: List = [] - + # @staticmethod # def parse_team_config(file_path: Union[str, Path]) -> SimpleNamespace: # """Parse JSON file into objects using SimpleNamespace.""" # with open(file_path, 'r') as f: # data = json.load(f) # return json.loads(json.dumps(data), object_hook=lambda d: SimpleNamespace(**d)) - - async def create_agent_from_config(self, agent_obj: SimpleNamespace) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]: + + async def create_agent_from_config( + self, agent_obj: SimpleNamespace + ) -> Union[FoundryAgentTemplate, ReasoningAgentTemplate, ProxyAgent]: """ Create an agent from configuration object. - + Args: agent_obj: Agent object from parsed JSON (SimpleNamespace) team_model: Model name to determine which template to use - + Returns: Configured agent instance - + Raises: UnsupportedModelError: If model is not supported InvalidConfigurationError: If configuration is invalid """ # Get model from agent config, team model, or environment - deployment_name = getattr(agent_obj, 'deployment_name', None) + deployment_name = getattr(agent_obj, "deployment_name", None) if not deployment_name and agent_obj.name.lower() == "proxyagent": self.logger.info("Creating ProxyAgent") user_id = current_user_id.get() return ProxyAgent(user_id=user_id) - + # Validate supported models supported_models = json.loads(config.SUPPORTED_MODELS) if deployment_name not in supported_models: - raise UnsupportedModelError(f"Model '{deployment_name}' not supported. Supported: {supported_models}") - + raise UnsupportedModelError( + f"Model '{deployment_name}' not supported. Supported: {supported_models}" + ) + # Determine which template to use - use_reasoning = deployment_name.startswith('o') - + use_reasoning = deployment_name.startswith("o") + # Validate reasoning template constraints if use_reasoning: - if getattr(agent_obj, 'use_bing', False) or getattr(agent_obj, 'coding_tools', False): + if getattr(agent_obj, "use_bing", False) or getattr( + agent_obj, "coding_tools", False + ): raise InvalidConfigurationError( f"ReasoningAgentTemplate cannot use Bing search or coding tools. " f"Agent '{agent_obj.name}' has use_bing={getattr(agent_obj, 'use_bing', False)}, " f"coding_tools={getattr(agent_obj, 'coding_tools', False)}" ) - - # Only create configs for explicitly requested capabilities - search_config = SearchConfig.from_env() if getattr(agent_obj, 'use_rag', False) else None - mcp_config = MCPConfig.from_env() if getattr(agent_obj, 'use_mcp', False) else None + search_config = ( + SearchConfig.from_env() if getattr(agent_obj, "use_rag", False) else None + ) + mcp_config = ( + MCPConfig.from_env() if getattr(agent_obj, "use_mcp", False) else None + ) # bing_config = BingConfig.from_env() if getattr(agent_obj, 'use_bing', False) else None - - self.logger.info(f"Creating agent '{agent_obj.name}' with model '{deployment_name}' " - f"(Template: {'Reasoning' if use_reasoning else 'Foundry'})") - + + self.logger.info( + f"Creating agent '{agent_obj.name}' with model '{deployment_name}' " + f"(Template: {'Reasoning' if use_reasoning else 'Foundry'})" + ) + # Create appropriate agent if use_reasoning: # Get reasoning specific configuration @@ -101,81 +108,93 @@ async def create_agent_from_config(self, agent_obj: SimpleNamespace) -> Union[Fo agent = ReasoningAgentTemplate( agent_name=agent_obj.name, - agent_description=getattr(agent_obj, 'description', ''), - agent_instructions=getattr(agent_obj, 'system_message', ''), + agent_description=getattr(agent_obj, "description", ""), + agent_instructions=getattr(agent_obj, "system_message", ""), model_deployment_name=deployment_name, azure_openai_endpoint=azure_openai_endpoint, search_config=search_config, - mcp_config=mcp_config + mcp_config=mcp_config, ) else: agent = FoundryAgentTemplate( agent_name=agent_obj.name, - agent_description=getattr(agent_obj, 'description', ''), - agent_instructions=getattr(agent_obj, 'system_message', ''), + agent_description=getattr(agent_obj, "description", ""), + agent_instructions=getattr(agent_obj, "system_message", ""), model_deployment_name=deployment_name, - enable_code_interpreter=getattr(agent_obj, 'coding_tools', False), + enable_code_interpreter=getattr(agent_obj, "coding_tools", False), mcp_config=mcp_config, - #bing_config=bing_config, - search_config=search_config + # bing_config=bing_config, + search_config=search_config, ) - + await agent.open() - self.logger.info(f"Successfully created and initialized agent '{agent_obj.name}'") + self.logger.info( + f"Successfully created and initialized agent '{agent_obj.name}'" + ) return agent async def get_agents(self, team_config_input: TeamConfiguration) -> List: """ Create and return a team of agents from JSON configuration. - + Args: team_config_input: team configuration object from cosmos db - + Returns: List of initialized agent instances """ # self.logger.info(f"Loading team configuration from: {file_path}") - + try: - + initalized_agents = [] - + for i, agent_cfg in enumerate(team_config_input.agents, 1): try: - self.logger.info(f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}") - + self.logger.info( + f"Creating agent {i}/{len(team_config_input.agents)}: {agent_cfg.name}" + ) + agent = await self.create_agent_from_config(agent_cfg) initalized_agents.append(agent) self._agent_list.append(agent) # Keep track for cleanup - - self.logger.info(f"āœ… Agent {i}/{len(team_config_input.agents)} created: {agent_cfg.name}") - + + self.logger.info( + f"āœ… Agent {i}/{len(team_config_input.agents)} created: {agent_cfg.name}" + ) + except (UnsupportedModelError, InvalidConfigurationError) as e: self.logger.warning(f"Skipped agent {agent_cfg.name}: {e}") continue except Exception as e: self.logger.error(f"Failed to create agent {agent_cfg.name}: {e}") continue - - self.logger.info(f"Successfully created {len(initalized_agents)}/{len(team_config_input.agents)} agents for team '{team_config_input.name}'") + + self.logger.info( + f"Successfully created {len(initalized_agents)}/{len(team_config_input.agents)} agents for team '{team_config_input.name}'" + ) return initalized_agents - + except Exception as e: self.logger.error(f"Failed to load team configuration: {e}") raise - @classmethod + @classmethod async def cleanup_all_agents(cls, agent_list: List): """Clean up all created agents.""" cls.logger = logging.getLogger(__name__) cls.logger.info(f"Cleaning up {len(agent_list)} agents") - + for agent in agent_list: try: await agent.close() except Exception as ex: - name = getattr(agent, "agent_name", getattr(agent, "__class__", type("X",(object,),{})).__name__) + name = getattr( + agent, + "agent_name", + getattr(agent, "__class__", type("X", (object,), {})).__name__, + ) cls.logger.warning(f"Error closing agent {name}: {ex}") - + agent_list.clear() cls.logger.info("Agent cleanup completed") diff --git a/src/backend/v3/magentic_agents/models/agent_models.py b/src/backend/v3/magentic_agents/models/agent_models.py index 66cd1cacb..40f19161d 100644 --- a/src/backend/v3/magentic_agents/models/agent_models.py +++ b/src/backend/v3/magentic_agents/models/agent_models.py @@ -1,6 +1,5 @@ """Models for agent configurations.""" -import os from dataclasses import dataclass from common.config.app_config import config @@ -9,6 +8,7 @@ @dataclass(slots=True) class MCPConfig: """Configuration for connecting to an MCP server.""" + url: str = "" name: str = "MCP" description: str = "" @@ -26,7 +26,7 @@ def from_env(cls) -> "MCPConfig": # Raise exception if any required environment variable is missing if not all([url, name, description, tenant_id, client_id]): raise ValueError(f"{cls.__name__} Missing required environment variables") - + return cls( url=url, name=name, @@ -35,6 +35,7 @@ def from_env(cls) -> "MCPConfig": client_id=client_id, ) + # @dataclass(slots=True) # class BingConfig: # """Configuration for connecting to Bing Search.""" @@ -52,9 +53,11 @@ def from_env(cls) -> "MCPConfig": # connection_name=connection_name, # ) + @dataclass(slots=True) class SearchConfig: """Configuration for connecting to Azure AI Search.""" + connection_name: str | None = None endpoint: str | None = None index_name: str | None = None @@ -69,8 +72,10 @@ def from_env(cls) -> "SearchConfig": # Raise exception if any required environment variable is missing if not all([connection_name, index_name, endpoint]): - raise ValueError(f"{cls.__name__} Missing required Azure Search environment variables") - + raise ValueError( + f"{cls.__name__} Missing required Azure Search environment variables" + ) + return cls( connection_name=connection_name, index_name=index_name, diff --git a/src/backend/v3/magentic_agents/proxy_agent.py b/src/backend/v3/magentic_agents/proxy_agent.py index 8ae9a28eb..d18f5eb92 100644 --- a/src/backend/v3/magentic_agents/proxy_agent.py +++ b/src/backend/v3/magentic_agents/proxy_agent.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -""" Proxy agent that prompts for human clarification.""" +"""Proxy agent that prompts for human clarification.""" import asyncio import logging @@ -9,39 +9,50 @@ from pydantic import Field from semantic_kernel.agents import ( # pylint: disable=no-name-in-module - AgentResponseItem, AgentThread) + AgentResponseItem, + AgentThread, +) from semantic_kernel.agents.agent import Agent -from semantic_kernel.contents import (AuthorRole, ChatMessageContent, - StreamingChatMessageContent) +from semantic_kernel.contents import ( + AuthorRole, + ChatMessageContent, + StreamingChatMessageContent, +) from semantic_kernel.contents.chat_history import ChatHistory -from semantic_kernel.contents.history_reducer.chat_history_reducer import \ - ChatHistoryReducer -from semantic_kernel.exceptions.agent_exceptions import \ - AgentThreadOperationException +from semantic_kernel.contents.history_reducer.chat_history_reducer import ( + ChatHistoryReducer, +) +from semantic_kernel.exceptions.agent_exceptions import AgentThreadOperationException from typing_extensions import override -from v3.callbacks.response_handlers import (agent_response_callback, - streaming_agent_response_callback) -from v3.config.settings import (connection_config, current_user_id, - orchestration_config) -from v3.models.messages import (UserClarificationRequest, - UserClarificationResponse, WebsocketMessageType) +from v3.callbacks.response_handlers import ( + agent_response_callback, + streaming_agent_response_callback, +) +from v3.config.settings import connection_config, current_user_id, orchestration_config +from v3.models.messages import ( + UserClarificationRequest, + UserClarificationResponse, + WebsocketMessageType, +) class DummyAgentThread(AgentThread): """Dummy thread implementation for proxy agent.""" - - def __init__(self, chat_history: ChatHistory | None = None, thread_id: str | None = None): + + def __init__( + self, chat_history: ChatHistory | None = None, thread_id: str | None = None + ): super().__init__() self._chat_history = chat_history if chat_history is not None else ChatHistory() self._id: str = thread_id or f"thread_{uuid.uuid4().hex}" self._is_deleted = False self.logger = logging.getLogger(__name__) - + @override async def _create(self) -> str: """Starts the thread and returns its ID.""" return self._id - + @override async def _delete(self) -> None: """Ends the current thread.""" @@ -67,7 +78,9 @@ async def get_messages(self) -> AsyncIterable[ChatMessageContent]: An async iterable of ChatMessageContent. """ if self._is_deleted: - raise AgentThreadOperationException("Cannot retrieve chat history, since the thread has been deleted.") + raise AgentThreadOperationException( + "Cannot retrieve chat history, since the thread has been deleted." + ) if self._id is None: await self.create() for message in self._chat_history.messages: @@ -76,7 +89,9 @@ async def get_messages(self) -> AsyncIterable[ChatMessageContent]: async def reduce(self) -> ChatHistory | None: """Reduce the chat history to a smaller size.""" if self._id is None: - raise AgentThreadOperationException("Cannot reduce chat history, since the thread is not currently active.") + raise AgentThreadOperationException( + "Cannot reduce chat history, since the thread is not currently active." + ) if not isinstance(self._chat_history, ChatHistoryReducer): return None return await self._chat_history.reduce() @@ -84,18 +99,21 @@ async def reduce(self) -> ChatHistory | None: class ProxyAgentResponseItem: """Response item wrapper for proxy agent responses.""" - + def __init__(self, message: ChatMessageContent, thread: AgentThread): self.message = message self.thread = thread self.logger = logging.getLogger(__name__) + class ProxyAgent(Agent): """Simple proxy agent that prompts for human clarification.""" # Declare as Pydantic field - user_id: Optional[str] = Field(default=None, description="User ID for WebSocket messaging") - + user_id: Optional[str] = Field( + default=None, description="User ID for WebSocket messaging" + ) + def __init__(self, user_id: str = None, **kwargs): # Get user_id from parameter or context, fallback to empty string effective_user_id = user_id or current_user_id.get() or "" @@ -103,22 +121,24 @@ def __init__(self, user_id: str = None, **kwargs): name="ProxyAgent", description="Call this agent when you need to clarify requests by asking the human user for more information. Ask it for more details about any unclear requirements, missing information, or if you need the user to elaborate on any aspect of the task.", user_id=effective_user_id, - **kwargs + **kwargs, ) self.instructions = "" - def _create_message_content(self, content: str, thread_id: str = None) -> ChatMessageContent: + def _create_message_content( + self, content: str, thread_id: str = None + ) -> ChatMessageContent: """Create a ChatMessageContent with proper metadata.""" return ChatMessageContent( role=AuthorRole.ASSISTANT, content=content, name=self.name, - metadata={"thread_id": thread_id} if thread_id else {} + metadata={"thread_id": thread_id} if thread_id else {}, ) async def _trigger_response_callbacks(self, message_content: ChatMessageContent): """Manually trigger the same response callbacks used by other agents.""" - # Get current user_id dynamically instead of using stored value + # Get current user_id dynamically instead of using stored value current_user = current_user_id.get() or self.user_id or "" # Trigger the standard agent response callback @@ -129,14 +149,15 @@ async def _trigger_streaming_callbacks(self, content: str, is_final: bool = Fals # Get current user_id dynamically instead of using stored value current_user = current_user_id.get() or self.user_id or "" streaming_message = StreamingChatMessageContent( - role=AuthorRole.ASSISTANT, - content=content, - name=self.name, - choice_index=0 + role=AuthorRole.ASSISTANT, content=content, name=self.name, choice_index=0 ) - await streaming_agent_response_callback(streaming_message, is_final, current_user) - - async def invoke(self, message: str,*, thread: AgentThread | None = None,**kwargs) -> AsyncIterator[ChatMessageContent]: + await streaming_agent_response_callback( + streaming_message, is_final, current_user + ) + + async def invoke( + self, message: str, *, thread: AgentThread | None = None, **kwargs + ) -> AsyncIterator[ChatMessageContent]: """Ask human user for clarification about the message.""" thread = await self._ensure_thread_exists_with_messages( @@ -151,31 +172,36 @@ async def invoke(self, message: str,*, thread: AgentThread | None = None,**kwarg clarification_message = UserClarificationRequest( question=clarification_request, - request_id=str(uuid.uuid4()) # Unique ID for the request + request_id=str(uuid.uuid4()), # Unique ID for the request ) # Send the approval request to the user's WebSocket - await connection_config.send_status_update_async({ - "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, - "data": clarification_message - }, user_id=current_user_id.get(), message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST) + await connection_config.send_status_update_async( + { + "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, + "data": clarification_message, + }, + user_id=current_user_id.get(), + message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, + ) # Get human input - human_response = await self._wait_for_user_clarification(clarification_message.request_id) - + human_response = await self._wait_for_user_clarification( + clarification_message.request_id + ) + if not human_response: human_response = "No additional clarification provided." - + response = f"Human clarification: {human_response}" chat_message = self._create_message_content(response, thread.id) - - yield AgentResponseItem( - message=chat_message, - thread=thread - ) - - async def invoke_stream(self, messages, thread=None, **kwargs) -> AsyncIterator[ProxyAgentResponseItem]: + + yield AgentResponseItem(message=chat_message, thread=thread) + + async def invoke_stream( + self, messages, thread=None, **kwargs + ) -> AsyncIterator[ProxyAgentResponseItem]: """Stream version - handles thread management for orchestration.""" thread = await self._ensure_thread_exists_with_messages( @@ -187,7 +213,11 @@ async def invoke_stream(self, messages, thread=None, **kwargs) -> AsyncIterator[ # Extract message content if isinstance(messages, list) and messages: - message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1]) + message = ( + messages[-1].content + if hasattr(messages[-1], "content") + else str(messages[-1]) + ) elif isinstance(messages, str): message = messages else: @@ -198,56 +228,65 @@ async def invoke_stream(self, messages, thread=None, **kwargs) -> AsyncIterator[ clarification_message = UserClarificationRequest( question=clarification_request, - request_id=str(uuid.uuid4()) # Unique ID for the request + request_id=str(uuid.uuid4()), # Unique ID for the request ) # Send the approval request to the user's WebSocket # The user_id will be automatically retrieved from context - await connection_config.send_status_update_async({ - "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, - "data": clarification_message - }, user_id=current_user_id.get(), message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST) + await connection_config.send_status_update_async( + { + "type": WebsocketMessageType.USER_CLARIFICATION_REQUEST, + "data": clarification_message, + }, + user_id=current_user_id.get(), + message_type=WebsocketMessageType.USER_CLARIFICATION_REQUEST, + ) # Get human input - replace with websocket call when available - human_response = await self._wait_for_user_clarification(clarification_message.request_id) - + human_response = await self._wait_for_user_clarification( + clarification_message.request_id + ) + if not human_response: human_response = "No additional clarification provided." - + response = f"Human clarification: {human_response}" chat_message = self._create_message_content(response, thread.id) - - yield AgentResponseItem( - message=chat_message, - thread=thread - ) - async def _wait_for_user_clarification(self, request_id:str) -> Optional[UserClarificationResponse]: + yield AgentResponseItem(message=chat_message, thread=thread) + + async def _wait_for_user_clarification( + self, request_id: str + ) -> Optional[UserClarificationResponse]: """Wait for user clarification response.""" # To do: implement timeout and error handling if request_id not in orchestration_config.clarifications: orchestration_config.clarifications[request_id] = None while orchestration_config.clarifications[request_id] is None: await asyncio.sleep(0.2) - return UserClarificationResponse(request_id=request_id,answer=orchestration_config.clarifications[request_id]) - + return UserClarificationResponse( + request_id=request_id, + answer=orchestration_config.clarifications[request_id], + ) + async def get_response(self, chat_history, **kwargs): """Get response from the agent - required by Agent base class.""" # Extract the latest user message - latest_message = chat_history.messages[-1].content if chat_history.messages else "" - + latest_message = ( + chat_history.messages[-1].content if chat_history.messages else "" + ) + # Use our invoke method to get the response async for response in self.invoke(latest_message, **kwargs): return response - + # Fallback if no response generated return ChatMessageContent( - role=AuthorRole.ASSISTANT, - content="No clarification provided." + role=AuthorRole.ASSISTANT, content="No clarification provided." ) - + async def create_proxy_agent(user_id: str = None): """Factory function for human proxy agent.""" - return ProxyAgent(user_id=user_id) \ No newline at end of file + return ProxyAgent(user_id=user_id) diff --git a/src/backend/v3/magentic_agents/reasoning_agent.py b/src/backend/v3/magentic_agents/reasoning_agent.py index 00c8659ce..915756cab 100644 --- a/src/backend/v3/magentic_agents/reasoning_agent.py +++ b/src/backend/v3/magentic_agents/reasoning_agent.py @@ -1,12 +1,9 @@ import logging -import os -from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential from common.config.app_config import config from semantic_kernel import Kernel from semantic_kernel.agents import ChatCompletionAgent # pylint: disable=E0611 from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion -from semantic_kernel.connectors.azure_ai_search import AzureAISearchCollection from v3.magentic_agents.common.lifecycle import MCPEnabledBase from v3.magentic_agents.models.agent_models import MCPConfig, SearchConfig from v3.magentic_agents.reasoning_search import ReasoningSearch diff --git a/src/backend/v3/magentic_agents/reasoning_search.py b/src/backend/v3/magentic_agents/reasoning_search.py index 38a6a28e0..7f944e7f5 100644 --- a/src/backend/v3/magentic_agents/reasoning_search.py +++ b/src/backend/v3/magentic_agents/reasoning_search.py @@ -4,47 +4,49 @@ """ from azure.core.credentials import AzureKeyCredential -from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential from azure.search.documents import SearchClient -from azure.search.documents.indexes import SearchIndexClient from semantic_kernel import Kernel -from semantic_kernel.connectors.ai.open_ai import AzureTextEmbedding -from semantic_kernel.connectors.azure_ai_search import ( - AzureAISearchCollection, AzureAISearchStore) from semantic_kernel.functions import kernel_function from v3.magentic_agents.models.agent_models import SearchConfig class ReasoningSearch: """Handles Azure AI Search integration for reasoning agents.""" - + def __init__(self, search_config: SearchConfig | None = None): self.search_config = search_config self.search_client: SearchClient | None = None - + async def initialize(self, kernel: Kernel) -> bool: """Initialize the search collection with embeddings and add it to the kernel.""" - if not self.search_config or not self.search_config.endpoint or not self.search_config.index_name: + if ( + not self.search_config + or not self.search_config.endpoint + or not self.search_config.index_name + ): print("Search configuration not available") return False - + try: - credential = SyncDefaultAzureCredential() - self.search_client = SearchClient(endpoint=self.search_config.endpoint, - credential=AzureKeyCredential(self.search_config.api_key), - index_name=self.search_config.index_name) - + self.search_client = SearchClient( + endpoint=self.search_config.endpoint, + credential=AzureKeyCredential(self.search_config.api_key), + index_name=self.search_config.index_name, + ) + # Add this class as a plugin so the agent can call search_documents kernel.add_plugin(self, plugin_name="knowledge_search") - - print(f"Added Azure AI Search plugin for index: {self.search_config.index_name}") + + print( + f"Added Azure AI Search plugin for index: {self.search_config.index_name}" + ) return True - + except Exception as ex: print(f"Could not initialize Azure AI Search: {ex}") return False - + @kernel_function( name="search_documents", description="Search the knowledge base for relevant documents and information. Use this when you need to find specific information from internal documents or data.", @@ -53,37 +55,39 @@ async def search_documents(self, query: str, limit: str = "3") -> str: """Search function that the agent can invoke to find relevant documents.""" if not self.search_client: return "Search service is not available." - + try: limit_int = int(limit) search_results = [] - results = self.search_client.search( - search_text=query, - query_type= "simple", + results = self.search_client.search( + search_text=query, + query_type="simple", select=["content"], - top=limit_int - ) - - for result in results: + top=limit_int, + ) + + for result in results: search_results.append(f"content: {result['content']}") - + if not search_results: return f"No relevant documents found for query: '{query}'" - + return search_results - + except Exception as ex: return f"Search failed: {str(ex)}" - + def is_available(self) -> bool: """Check if search functionality is available.""" return self.search_client is not None # Simple factory function -async def create_reasoning_search(kernel: Kernel, search_config: SearchConfig | None) -> ReasoningSearch: +async def create_reasoning_search( + kernel: Kernel, search_config: SearchConfig | None +) -> ReasoningSearch: """Create and initialize a ReasoningSearch instance.""" search = ReasoningSearch(search_config) await search.initialize(kernel) - return search \ No newline at end of file + return search diff --git a/src/backend/v3/models/messages.py b/src/backend/v3/models/messages.py index 4605723f1..8eb4187c8 100644 --- a/src/backend/v3/models/messages.py +++ b/src/backend/v3/models/messages.py @@ -1,18 +1,19 @@ """Messages from the backend to the frontend via WebSocket.""" -import uuid +import time from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Any, Dict, List, Literal, Optional -import time -from semantic_kernel.kernel_pydantic import Field, KernelBaseModel +from typing import Any, Dict, List, Optional + from common.models.messages_kernel import AgentMessageType +from semantic_kernel.kernel_pydantic import KernelBaseModel from v3.models.models import MPlan, PlanStatus @dataclass(slots=True) class AgentMessage: """Message from the backend to the frontend via WebSocket.""" + agent_name: str timestamp: str content: str @@ -21,19 +22,25 @@ def to_dict(self) -> Dict[str, Any]: """Convert the AgentMessage to a dictionary for JSON serialization.""" return asdict(self) + @dataclass(slots=True) class AgentStreamStart: """Start of a streaming message from the backend to the frontend via WebSocket.""" + agent_name: str + @dataclass(slots=True) class AgentStreamEnd: """End of a streaming message from the backend to the frontend via WebSocket.""" + agent_name: str + @dataclass(slots=True) class AgentMessageStreaming: """Streaming message from the backend to the frontend via WebSocket.""" + agent_name: str content: str is_final: bool = False @@ -42,19 +49,23 @@ def to_dict(self) -> Dict[str, Any]: """Convert the AgentMessageStreaming to a dictionary for JSON serialization.""" return asdict(self) + @dataclass(slots=True) class AgentToolMessage: """Message from an agent using a tool.""" + agent_name: str - tool_calls: List['AgentToolCall'] = field(default_factory=list) + tool_calls: List["AgentToolCall"] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: """Convert the AgentToolMessage to a dictionary for JSON serialization.""" return asdict(self) - + + @dataclass(slots=True) class AgentToolCall: """Message representing a tool call from an agent.""" + tool_name: str arguments: Dict[str, Any] @@ -62,52 +73,66 @@ def to_dict(self) -> Dict[str, Any]: """Convert the AgentToolCall to a dictionary for JSON serialization.""" return asdict(self) + @dataclass(slots=True) class PlanApprovalRequest: """Request for plan approval from the frontend.""" + plan: MPlan status: PlanStatus context: dict | None = None + @dataclass(slots=True) class PlanApprovalResponse: """Response for plan approval from the frontend.""" + m_plan_id: str approved: bool feedback: str | None = None plan_id: str | None = None + @dataclass(slots=True) class ReplanApprovalRequest: """Request for replan approval from the frontend.""" + new_plan: MPlan reason: str context: dict | None = None + @dataclass(slots=True) -class ReplanApprovalResponse: +class ReplanApprovalResponse: """Response for replan approval from the frontend.""" + plan_id: str approved: bool feedback: str | None = None + @dataclass(slots=True) class UserClarificationRequest: """Request for user clarification from the frontend.""" + question: str request_id: str + @dataclass(slots=True) class UserClarificationResponse: """Response for user clarification from the frontend.""" + request_id: str answer: str = "" plan_id: str = "" m_plan_id: str = "" + @dataclass(slots=True) class FinalResultMessage: """Final result message from the backend to the frontend.""" + content: str # Changed from 'result' to 'content' to match frontend expectations status: str = "completed" # Added status field (defaults to 'completed') timestamp: Optional[float] = None # Added timestamp field @@ -119,7 +144,7 @@ def to_dict(self) -> Dict[str, Any]: data = { "content": self.content, "status": self.status, - "timestamp": self.timestamp or time.time() + "timestamp": self.timestamp or time.time(), } if self.summary: data["summary"] = self.summary @@ -137,9 +162,11 @@ class ApprovalRequest(KernelBaseModel): action: str agent_name: str + @dataclass(slots=True) class AgentMessageResponse: """Response message representing an agent's message.""" + plan_id: str agent: str content: str @@ -151,7 +178,8 @@ class AgentMessageResponse: class WebsocketMessageType(str, Enum): """Types of WebSocket messages.""" - SYSTEM_MESSAGE = "system_message" + + SYSTEM_MESSAGE = "system_message" AGENT_MESSAGE = "agent_message" AGENT_STREAM_START = "agent_stream_start" AGENT_STREAM_END = "agent_stream_end" diff --git a/src/backend/v3/models/models.py b/src/backend/v3/models/models.py index 26e512547..adcf1fe88 100644 --- a/src/backend/v3/models/models.py +++ b/src/backend/v3/models/models.py @@ -1,9 +1,6 @@ import uuid -from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, List, Literal, Optional - -from dataclasses import asdict, dataclass, field +from typing import List from pydantic import BaseModel, Field @@ -11,7 +8,7 @@ class PlanStatus(str, Enum): CREATED = "created" QUEUED = "queued" - RUNNING = "running" + RUNNING = "running" COMPLETED = "completed" FAILED = "failed" CANCELLED = "cancelled" @@ -19,12 +16,14 @@ class PlanStatus(str, Enum): class MStep(BaseModel): """model of a step in a plan""" + agent: str = "" action: str = "" class MPlan(BaseModel): """model of a plan""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: str = "" team_id: str = "" @@ -34,4 +33,3 @@ class MPlan(BaseModel): team: List[str] = [] facts: str = "" steps: List[MStep] = [] - diff --git a/src/backend/v3/models/orchestration_models.py b/src/backend/v3/models/orchestration_models.py index ef9f0759a..8c63c1234 100644 --- a/src/backend/v3/models/orchestration_models.py +++ b/src/backend/v3/models/orchestration_models.py @@ -1,17 +1,16 @@ -from enum import Enum -from typing import List, Optional, TypedDict +from typing import List, Optional -from semantic_kernel.kernel_pydantic import Field, KernelBaseModel +from semantic_kernel.kernel_pydantic import KernelBaseModel +# Add other agents as needed - # Add other agents as needed - # Define agents drawing on the magentic team output -class AgentDefinition: +class AgentDefinition: def __init__(self, name, description): self.name = name self.description = description + def __repr__(self): return f"Agent(name={self.name!r}, description={self.description!r})" @@ -22,11 +21,10 @@ class PlannerResponseStep(KernelBaseModel): action: str - class PlannerResponsePlan(KernelBaseModel): request: str team: List[AgentDefinition] facts: str steps: List[PlannerResponseStep] summary_plan_and_steps: str - human_clarification_request: Optional[str] = None \ No newline at end of file + human_clarification_request: Optional[str] = None diff --git a/src/backend/v3/orchestration/helper/plan_to_mplan_converter.py b/src/backend/v3/orchestration/helper/plan_to_mplan_converter.py index 13c0d7a69..bc1dd5346 100644 --- a/src/backend/v3/orchestration/helper/plan_to_mplan_converter.py +++ b/src/backend/v3/orchestration/helper/plan_to_mplan_converter.py @@ -33,9 +33,9 @@ class PlanToMPlanConverter: """ - BULLET_RE = re.compile(r'^(?P\s*)[-•*]\s+(?P.+)$') - BOLD_AGENT_RE = re.compile(r'\*\*([A-Za-z0-9_]+)\*\*') - STRIP_BULLET_MARKER_RE = re.compile(r'^[-•*]\s+') + BULLET_RE = re.compile(r"^(?P\s*)[-•*]\s+(?P.+)$") + BOLD_AGENT_RE = re.compile(r"\*\*([A-Za-z0-9_]+)\*\*") + STRIP_BULLET_MARKER_RE = re.compile(r"^[-•*]\s+") def __init__( self, @@ -150,7 +150,7 @@ def _try_bold_agent(self, text: str) -> (Optional[str], str): candidate = m.group(1) canonical = self._team_lookup.get(candidate.lower()) if canonical: # valid agent - cleaned = text[:m.start()] + text[m.end():] + cleaned = text[: m.start()] + text[m.end() :] return canonical, cleaned.strip() return None, text @@ -169,7 +169,7 @@ def _finalize_action(self, action: str) -> str: if self.trim_actions: action = action.strip() if self.collapse_internal_whitespace: - action = re.sub(r'\s+', ' ', action) + action = re.sub(r"\s+", " ", action) return action # --------------- Convenience (static) --------------- # @@ -191,4 +191,4 @@ def convert( task=task, facts=facts, **kwargs, - ).parse(plan_text) \ No newline at end of file + ).parse(plan_text) diff --git a/src/backend/v3/orchestration/human_approval_manager.py b/src/backend/v3/orchestration/human_approval_manager.py index 54c8123fe..7eef305ad 100644 --- a/src/backend/v3/orchestration/human_approval_manager.py +++ b/src/backend/v3/orchestration/human_approval_manager.py @@ -5,26 +5,29 @@ import asyncio import logging -import re from typing import Any, Optional import v3.models.messages as messages from semantic_kernel.agents.orchestration.magentic import ( - MagenticContext, ProgressLedger, ProgressLedgerItem, - StandardMagenticManager) + MagenticContext, + ProgressLedger, + ProgressLedgerItem, + StandardMagenticManager, +) from semantic_kernel.agents.orchestration.prompts._magentic_prompts import ( - ORCHESTRATOR_FINAL_ANSWER_PROMPT, ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT, - ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT) + ORCHESTRATOR_FINAL_ANSWER_PROMPT, + ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT, + ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT, +) from semantic_kernel.contents import ChatMessageContent -from v3.config.settings import (connection_config, current_user_id, - orchestration_config) -from v3.models.models import MPlan, MStep -from v3.orchestration.helper.plan_to_mplan_converter import \ - PlanToMPlanConverter +from v3.config.settings import connection_config, current_user_id, orchestration_config +from v3.models.models import MPlan +from v3.orchestration.helper.plan_to_mplan_converter import PlanToMPlanConverter # Using a module level logger to avoid pydantic issues around inherited fields logger = logging.getLogger(__name__) + # Create a progress ledger that indicates the request is satisfied (task completed) class HumanApprovalMagenticManager(StandardMagenticManager): """ @@ -35,7 +38,7 @@ class HumanApprovalMagenticManager(StandardMagenticManager): # Define Pydantic fields to avoid validation errors approval_enabled: bool = True magentic_plan: Optional[MPlan] = None - current_user_id: Optional[str] = None + current_user_id: Optional[str] = None def __init__(self, *args, **kwargs): # Remove any custom kwargs before passing to parent @@ -43,13 +46,12 @@ def __init__(self, *args, **kwargs): plan_append = """ IMPORTANT: Never ask the user for information or clarification until all agents on the team have been asked first. -EXAMPLE: If the user request involves product information, first ask all agents on the team to provide the information. +EXAMPLE: If the user request involves product information, first ask all agents on the team to provide the information. Do not ask the user unless all agents have been consulted and the information is still missing. Plan steps should always include a bullet point, followed by an agent name, followed by a description of the action -to be taken. If a step involves multiple actions, separate them into distinct steps with an agent included in each step. If the step is taken by an agent that -is not part of the team, such as the MagenticManager, please always list the MagenticManager as the agent for that step. At any time, if more information is -needed from the user, use the ProxyAgent to request this information. +to be taken. If a step involves multiple actions, separate them into distinct steps with an agent included in each step. +If the step is taken by an agent that is not part of the team, such as the MagenticManager, please always list the MagenticManager as the agent for that step. At any time, if more information is needed from the user, use the ProxyAgent to request this information. Here is an example of a well-structured plan: - **EnhancedResearchAgent** to gather authoritative data on the latest industry trends and best practices in employee onboarding @@ -66,9 +68,13 @@ def __init__(self, *args, **kwargs): """ # kwargs["task_ledger_facts_prompt"] = ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT + facts_append - kwargs['task_ledger_plan_prompt'] = ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT + plan_append - kwargs['task_ledger_plan_update_prompt'] = ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT + plan_append - kwargs['final_answer_prompt'] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append + kwargs["task_ledger_plan_prompt"] = ( + ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT + plan_append + ) + kwargs["task_ledger_plan_update_prompt"] = ( + ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT + plan_append + ) + kwargs["final_answer_prompt"] = ORCHESTRATOR_FINAL_ANSWER_PROMPT + final_append super().__init__(*args, **kwargs) @@ -78,7 +84,7 @@ async def plan(self, magentic_context: MagenticContext) -> Any: """ # Extract task text from the context task_text = magentic_context.task - if hasattr(task_text, 'content'): + if hasattr(task_text, "content"): task_text = task_text.content elif not isinstance(task_text, str): task_text = str(task_text) @@ -90,9 +96,9 @@ async def plan(self, magentic_context: MagenticContext) -> Any: # First, let the parent create the actual plan logger.info(" Creating execution plan...") plan = await super().plan(magentic_context) - logger.info(" Plan created: %s",plan) + logger.info(" Plan created: %s", plan) - self.magentic_plan = self.plan_to_obj( magentic_context, self.task_ledger) + self.magentic_plan = self.plan_to_obj(magentic_context, self.task_ledger) self.magentic_plan.user_id = current_user_id.get() @@ -100,23 +106,27 @@ async def plan(self, magentic_context: MagenticContext) -> Any: approval_message = messages.PlanApprovalRequest( plan=self.magentic_plan, status="PENDING_APPROVAL", - context={ - "task": task_text, - "participant_descriptions": magentic_context.participant_descriptions - } if hasattr(magentic_context, 'participant_descriptions') else {} + context=( + { + "task": task_text, + "participant_descriptions": magentic_context.participant_descriptions, + } + if hasattr(magentic_context, "participant_descriptions") + else {} + ), ) try: - orchestration_config.plans[self.magentic_plan.id] = self.magentic_plan + orchestration_config.plans[self.magentic_plan.id] = self.magentic_plan except Exception as e: logger.error("Error processing plan approval: %s", e) - # Send the approval request to the user's WebSocket # The user_id will be automatically retrieved from context await connection_config.send_status_update_async( message=approval_message, user_id=current_user_id.get(), - message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST) + message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST, + ) # Wait for user approval approval_response = await self._wait_for_user_approval(approval_message.plan.id) @@ -126,15 +136,19 @@ async def plan(self, magentic_context: MagenticContext) -> Any: return plan else: logger.debug("Plan execution cancelled by user") - await connection_config.send_status_update_async({ - "type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, - "data": approval_response - }, user_id=current_user_id.get(), message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE) - raise Exception("Plan execution cancelled by user") - - async def replan(self,magentic_context: MagenticContext) -> Any: - """ - Override to add websocket messages for replanning events. + await connection_config.send_status_update_async( + { + "type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, + "data": approval_response, + }, + user_id=current_user_id.get(), + message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE, + ) + raise Exception("Plan execution cancelled by user") + + async def replan(self, magentic_context: MagenticContext) -> Any: + """ + Override to add websocket messages for replanning events. """ logger.info("\nHuman-in-the-Loop Magentic Manager replanned:") @@ -142,33 +156,45 @@ async def replan(self,magentic_context: MagenticContext) -> Any: logger.info("Replanned: %s", replan) return replan - async def create_progress_ledger(self, magentic_context: MagenticContext) -> ProgressLedger: - """ Check for max rounds exceeded and send final message if so. """ + async def create_progress_ledger( + self, magentic_context: MagenticContext + ) -> ProgressLedger: + """Check for max rounds exceeded and send final message if so.""" if magentic_context.round_count >= orchestration_config.max_rounds: # Send final message to user final_message = messages.FinalResultMessage( content="Process terminated: Maximum rounds exceeded", status="terminated", - summary=f"Stopped after {magentic_context.round_count} rounds (max: {orchestration_config.max_rounds})" + summary=f"Stopped after {magentic_context.round_count} rounds (max: {orchestration_config.max_rounds})", ) await connection_config.send_status_update_async( - message= final_message, + message=final_message, user_id=current_user_id.get(), - message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE) + message_type=messages.WebsocketMessageType.FINAL_RESULT_MESSAGE, + ) return ProgressLedger( - is_request_satisfied=ProgressLedgerItem(reason="Maximum rounds exceeded", answer=True), + is_request_satisfied=ProgressLedgerItem( + reason="Maximum rounds exceeded", answer=True + ), is_in_loop=ProgressLedgerItem(reason="Terminating", answer=False), - is_progress_being_made=ProgressLedgerItem(reason="Terminating", answer=False), + is_progress_being_made=ProgressLedgerItem( + reason="Terminating", answer=False + ), next_speaker=ProgressLedgerItem(reason="Task complete", answer=""), - instruction_or_question=ProgressLedgerItem(reason="Task complete", answer="Process terminated due to maximum rounds exceeded") + instruction_or_question=ProgressLedgerItem( + reason="Task complete", + answer="Process terminated due to maximum rounds exceeded", + ), ) return await super().create_progress_ledger(magentic_context) # plan_id will not be optional in future - async def _wait_for_user_approval(self, m_plan_id: Optional[str] = None) -> Optional[messages.PlanApprovalResponse]: + async def _wait_for_user_approval( + self, m_plan_id: Optional[str] = None + ) -> Optional[messages.PlanApprovalResponse]: """Wait for user approval response.""" # To do: implement timeout and error handling @@ -176,21 +202,29 @@ async def _wait_for_user_approval(self, m_plan_id: Optional[str] = None) -> Opti orchestration_config.approvals[m_plan_id] = None while orchestration_config.approvals[m_plan_id] is None: await asyncio.sleep(0.2) - return messages.PlanApprovalResponse(approved=orchestration_config.approvals[m_plan_id], m_plan_id=m_plan_id) + return messages.PlanApprovalResponse( + approved=orchestration_config.approvals[m_plan_id], m_plan_id=m_plan_id + ) - async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessageContent: + async def prepare_final_answer( + self, magentic_context: MagenticContext + ) -> ChatMessageContent: """ Override to ensure final answer is prepared after all steps are executed. """ logger.info("\n Magentic Manager - Preparing final answer...") return await super().prepare_final_answer(magentic_context) - def plan_to_obj(self, magentic_context, ledger) -> MPlan: - """ Convert the generated plan from the ledger into a structured MPlan object. """ + """Convert the generated plan from the ledger into a structured MPlan object.""" - return_plan: MPlan = PlanToMPlanConverter.convert(plan_text=ledger.plan.content,facts=ledger.facts.content, team=list(magentic_context.participant_descriptions.keys()), task=magentic_context.task) + return_plan: MPlan = PlanToMPlanConverter.convert( + plan_text=ledger.plan.content, + facts=ledger.facts.content, + team=list(magentic_context.participant_descriptions.keys()), + task=magentic_context.task, + ) # # get the request text from the ledger # if hasattr(magentic_context, 'task'): diff --git a/src/backend/v3/orchestration/orchestration_manager.py b/src/backend/v3/orchestration/orchestration_manager.py index 82f2eeb9d..c62452e71 100644 --- a/src/backend/v3/orchestration/orchestration_manager.py +++ b/src/backend/v3/orchestration/orchestration_manager.py @@ -1,10 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. -""" Orchestration manager to handle the orchestration logic. """ - +"""Orchestration manager to handle the orchestration logic.""" import asyncio import contextvars import logging -import os import uuid from contextvars import ContextVar from typing import List, Optional @@ -14,22 +12,30 @@ from common.models.messages_kernel import TeamConfiguration from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration from semantic_kernel.agents.runtime import InProcessRuntime + # Create custom execution settings to fix schema issues from semantic_kernel.connectors.ai.open_ai import ( - AzureChatCompletion, OpenAIChatPromptExecutionSettings) -from semantic_kernel.contents import (ChatMessageContent, - StreamingChatMessageContent) -from v3.callbacks.response_handlers import (agent_response_callback, - streaming_agent_response_callback) -from v3.config.settings import (config, connection_config, current_user_id, - orchestration_config) + AzureChatCompletion, + OpenAIChatPromptExecutionSettings, +) +from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent +from v3.callbacks.response_handlers import ( + agent_response_callback, + streaming_agent_response_callback, +) +from v3.config.settings import ( + connection_config, + orchestration_config, +) from v3.magentic_agents.magentic_agent_factory import MagenticAgentFactory from v3.models.messages import WebsocketMessageType -from v3.orchestration.human_approval_manager import \ - HumanApprovalMagenticManager +from v3.orchestration.human_approval_manager import HumanApprovalMagenticManager # Context variable to hold the current user ID -current_user_id: ContextVar[Optional[str]] = contextvars.ContextVar("current_user_id", default=None) +current_user_id: ContextVar[Optional[str]] = contextvars.ContextVar( + "current_user_id", default=None +) + class OrchestrationManager: """Manager for handling orchestration logic.""" @@ -43,13 +49,14 @@ def __init__(self): self.logger = self.__class__.logger @classmethod - async def init_orchestration(cls, agents: List, user_id: str = None)-> MagenticOrchestration: + async def init_orchestration( + cls, agents: List, user_id: str = None + ) -> MagenticOrchestration: """Main function to run the agents.""" # Custom execution settings that should work with Azure OpenAI execution_settings = OpenAIChatPromptExecutionSettings( - max_tokens=4000, - temperature=0.1 + max_tokens=4000, temperature=0.1 ) credential = SyncDefaultAzureCredential() @@ -57,7 +64,7 @@ async def init_orchestration(cls, agents: List, user_id: str = None)-> MagenticO def get_token(): token = credential.get_token("https://cognitiveservices.azure.com/.default") return token.token - + # 1. Create a Magentic orchestration with Azure OpenAI magentic_orchestration = MagenticOrchestration( members=agents, @@ -65,34 +72,48 @@ def get_token(): chat_completion_service=AzureChatCompletion( deployment_name=config.AZURE_OPENAI_DEPLOYMENT_NAME, endpoint=config.AZURE_OPENAI_ENDPOINT, - ad_token_provider=get_token # Use token provider function + ad_token_provider=get_token, # Use token provider function ), - execution_settings=execution_settings + execution_settings=execution_settings, ), agent_response_callback=cls._user_aware_agent_callback(user_id), - streaming_agent_response_callback=cls._user_aware_streaming_callback(user_id) + streaming_agent_response_callback=cls._user_aware_streaming_callback( + user_id + ), ) return magentic_orchestration - + @staticmethod def _user_aware_agent_callback(user_id: str): """Factory method that creates a callback with captured user_id""" + def callback(message: ChatMessageContent): return agent_response_callback(message, user_id) + return callback - + @staticmethod def _user_aware_streaming_callback(user_id: str): """Factory method that creates a streaming callback with captured user_id""" - async def callback(streaming_message: StreamingChatMessageContent, is_final: bool): - return await streaming_agent_response_callback(streaming_message, is_final, user_id) + + async def callback( + streaming_message: StreamingChatMessageContent, is_final: bool + ): + return await streaming_agent_response_callback( + streaming_message, is_final, user_id + ) + return callback - + @classmethod - async def get_current_or_new_orchestration(cls, user_id: str, team_config: TeamConfiguration, team_switched: bool) -> MagenticOrchestration: # add team_switched: bool parameter + async def get_current_or_new_orchestration( + cls, user_id: str, team_config: TeamConfiguration, team_switched: bool + ) -> MagenticOrchestration: # add team_switched: bool parameter """get existing orchestration instance.""" current_orchestration = orchestration_config.get_current_orchestration(user_id) - if current_orchestration is None or team_switched: # add check for team_switched flag + if ( + current_orchestration is None or team_switched + ): # add check for team_switched flag if current_orchestration is not None and team_switched: for agent in current_orchestration._members: if agent.name != "ProxyAgent": @@ -102,31 +123,35 @@ async def get_current_or_new_orchestration(cls, user_id: str, team_config: TeamC cls.logger.error("Error closing agent: %s", e) factory = MagenticAgentFactory() agents = await factory.get_agents(team_config_input=team_config) - orchestration_config.orchestrations[user_id] = await cls.init_orchestration(agents, user_id) + orchestration_config.orchestrations[user_id] = await cls.init_orchestration( + agents, user_id + ) return orchestration_config.get_current_orchestration(user_id) async def run_orchestration(self, user_id, input_task) -> None: - """ Run the orchestration with user input loop.""" + """Run the orchestration with user input loop.""" token = current_user_id.set(user_id) job_id = str(uuid.uuid4()) orchestration_config.approvals[job_id] = None - + magentic_orchestration = orchestration_config.get_current_orchestration(user_id) if magentic_orchestration is None: raise ValueError("Orchestration not initialized for user.") - + try: - if hasattr(magentic_orchestration, '_manager') and hasattr(magentic_orchestration._manager, 'current_user_id'): + if hasattr(magentic_orchestration, "_manager") and hasattr( + magentic_orchestration._manager, "current_user_id" + ): magentic_orchestration._manager.current_user_id = user_id self.logger.debug(f"DEBUG: Set user_id on manager = {user_id}") except Exception as e: self.logger.error(f"Error setting user_id on manager: {e}") - + runtime = InProcessRuntime() runtime.start() - + try: orchestration_result = await magentic_orchestration.invoke( @@ -141,19 +166,23 @@ async def run_orchestration(self, user_id, input_task) -> None: self.logger.info("=" * 50) # Send final result via WebSocket - await connection_config.send_status_update_async({ - "type": WebsocketMessageType.FINAL_RESULT_MESSAGE, - "data": { - "content": str(value), - "status": "completed", - "timestamp": asyncio.get_event_loop().time() - } - }, user_id, message_type=WebsocketMessageType.FINAL_RESULT_MESSAGE) + await connection_config.send_status_update_async( + { + "type": WebsocketMessageType.FINAL_RESULT_MESSAGE, + "data": { + "content": str(value), + "status": "completed", + "timestamp": asyncio.get_event_loop().time(), + }, + }, + user_id, + message_type=WebsocketMessageType.FINAL_RESULT_MESSAGE, + ) self.logger.info(f"Final result sent via WebSocket to user {user_id}") except Exception as e: self.logger.info(f"Error: {e}") self.logger.info(f"Error type: {type(e).__name__}") - if hasattr(e, '__dict__'): + if hasattr(e, "__dict__"): self.logger.info(f"Error attributes: {e.__dict__}") self.logger.info("=" * 50) @@ -162,4 +191,3 @@ async def run_orchestration(self, user_id, input_task) -> None: finally: await runtime.stop_when_idle() current_user_id.reset(token) - diff --git a/src/tests/agents/test_foundry_integration.py b/src/tests/agents/test_foundry_integration.py index 0d3325c06..d1febec71 100644 --- a/src/tests/agents/test_foundry_integration.py +++ b/src/tests/agents/test_foundry_integration.py @@ -274,4 +274,4 @@ async def test_multiple_capabilities_together(self): if __name__ == "__main__": """Run the tests directly for debugging.""" - pytest.main([__file__, "-v", "-s"]) \ No newline at end of file + pytest.main([__file__, "-v", "-s"])