1414from fastapi import WebSocket
1515from semantic_kernel .agents .orchestration .magentic import MagenticOrchestration
1616from semantic_kernel .connectors .ai .open_ai import (
17- AzureChatCompletion , OpenAIChatPromptExecutionSettings )
17+ AzureChatCompletion ,
18+ OpenAIChatPromptExecutionSettings ,
19+ )
20+
21+ from v3 .models .messages import WebsocketMessageType
1822
1923logger = logging .getLogger (__name__ )
2024
2125# Create a context variable to track current user
22- current_user_id : contextvars .ContextVar [Optional [str ]] = contextvars .ContextVar ('current_user_id' , default = None )
26+ current_user_id : contextvars .ContextVar [Optional [str ]] = contextvars .ContextVar (
27+ "current_user_id" , default = None
28+ )
29+
2330
2431class AzureConfig :
2532 """Azure OpenAI and authentication configuration."""
@@ -37,7 +44,7 @@ def ad_token_provider(self) -> str:
3744 token = self .credential .get_token (config .AZURE_COGNITIVE_SERVICES )
3845 return token .token
3946
40- async def create_chat_completion_service (self , use_reasoning_model : bool = False ):
47+ async def create_chat_completion_service (self , use_reasoning_model : bool = False ):
4148 """Create Azure Chat Completion service."""
4249 model_name = (
4350 self .reasoning_model if use_reasoning_model else self .standard_model
@@ -75,16 +82,19 @@ class OrchestrationConfig:
7582 """Configuration for orchestration settings."""
7683
7784 def __init__ (self ):
78- self .orchestrations : Dict [str , MagenticOrchestration ] = {} # user_id -> orchestration instance
79- self .plans : Dict [str , any ] = {} # plan_id -> plan details
80- self .approvals : Dict [str , bool ] = {} # m_plan_id -> approval status
81- self .sockets : Dict [str , WebSocket ] = {} # user_id -> WebSocket
82- self .clarifications : Dict [str , str ] = {} # m_plan_id -> clarification response
85+ self .orchestrations : Dict [str , MagenticOrchestration ] = (
86+ {}
87+ ) # user_id -> orchestration instance
88+ self .plans : Dict [str , any ] = {} # plan_id -> plan details
89+ self .approvals : Dict [str , bool ] = {} # m_plan_id -> approval status
90+ self .sockets : Dict [str , WebSocket ] = {} # user_id -> WebSocket
91+ self .clarifications : Dict [str , str ] = {} # m_plan_id -> clarification response
8392
8493 def get_current_orchestration (self , user_id : str ) -> MagenticOrchestration :
8594 """get existing orchestration instance."""
8695 return self .orchestrations .get (user_id , None )
87-
96+
97+
8898class ConnectionConfig :
8999 """Connection manager for WebSocket connections."""
90100
@@ -93,15 +103,19 @@ def __init__(self):
93103 # Map user_id to process_id for context-based messaging
94104 self .user_to_process : Dict [str , str ] = {}
95105
96- def add_connection (self , process_id : str , connection : WebSocket , user_id : str = None ):
106+ def add_connection (
107+ self , process_id : str , connection : WebSocket , user_id : str = None
108+ ):
97109 """Add a new connection."""
98110 # Close existing connection if it exists
99111 if process_id in self .connections :
100112 try :
101113 asyncio .create_task (self .connections [process_id ].close ())
102114 except Exception as e :
103- logger .error (f"Error closing existing connection for user { process_id } : { e } " )
104-
115+ logger .error (
116+ f"Error closing existing connection for user { process_id } : { e } "
117+ )
118+
105119 self .connections [process_id ] = connection
106120 # Map user to process for context-based messaging
107121 if user_id :
@@ -114,12 +128,18 @@ def add_connection(self, process_id: str, connection: WebSocket, user_id: str =
114128 try :
115129 asyncio .create_task (old_connection .close ())
116130 del self .connections [old_process_id ]
117- logger .info (f"Closed old connection { old_process_id } for user { user_id } " )
131+ logger .info (
132+ f"Closed old connection { old_process_id } for user { user_id } "
133+ )
118134 except Exception as e :
119- logger .error (f"Error closing old connection for user { user_id } : { e } " )
120-
135+ logger .error (
136+ f"Error closing old connection for user { user_id } : { e } "
137+ )
138+
121139 self .user_to_process [user_id ] = process_id
122- logger .info (f"WebSocket connection added for process: { process_id } (user: { user_id } )" )
140+ logger .info (
141+ f"WebSocket connection added for process: { process_id } (user: { user_id } )"
142+ )
123143 else :
124144 logger .info (f"WebSocket connection added for process: { process_id } " )
125145
@@ -128,7 +148,7 @@ def remove_connection(self, process_id):
128148 process_id = str (process_id )
129149 if process_id in self .connections :
130150 del self .connections [process_id ]
131-
151+
132152 # Remove from user mapping if exists
133153 for user_id , mapped_process_id in list (self .user_to_process .items ()):
134154 if mapped_process_id == process_id :
@@ -139,7 +159,7 @@ def remove_connection(self, process_id):
139159 def get_connection (self , process_id ):
140160 """Get a connection."""
141161 return self .connections .get (process_id )
142-
162+
143163 async def close_connection (self , process_id ):
144164 """Remove a connection."""
145165 connection = self .get_connection (process_id )
@@ -156,22 +176,29 @@ async def close_connection(self, process_id):
156176 self .remove_connection (process_id )
157177 logger .info ("Connection removed for batch ID: %s" , process_id )
158178
159- async def send_status_update_async (self , message : any , user_id : Optional [str ] = None ):
179+ async def send_status_update_async (
180+ self ,
181+ message : any ,
182+ user_id : Optional [str ] = None ,
183+ message_type : WebsocketMessageType = WebsocketMessageType .SYSTEM_MESSAGE ,
184+ ):
160185 """Send a status update to a specific client."""
161186 # If no process_id provided, get from context
162187 if user_id is None :
163188 user_id = current_user_id .get ()
164-
189+
165190 if not user_id :
166191 logger .warning ("No user_id available for WebSocket message" )
167192 return
168-
193+
169194 process_id = self .user_to_process .get (user_id )
170195 if not process_id :
171196 logger .warning ("No active WebSocket process found for user ID: %s" , user_id )
172- logger .debug (f"Available user mappings: { list (self .user_to_process .keys ())} " )
197+ logger .debug (
198+ f"Available user mappings: { list (self .user_to_process .keys ())} "
199+ )
173200 return
174-
201+
175202 connection = self .get_connection (process_id )
176203 if connection :
177204 try :
@@ -183,7 +210,9 @@ async def send_status_update_async(self, message: any, user_id: Optional[str] =
183210 # Clean up stale connection
184211 self .remove_connection (process_id )
185212 else :
186- logger .warning ("No connection found for process ID: %s (user: %s)" , process_id , user_id )
213+ logger .warning (
214+ "No connection found for process ID: %s (user: %s)" , process_id , user_id
215+ )
187216 # Clean up stale mapping
188217 if user_id in self .user_to_process :
189218 del self .user_to_process [user_id ]
@@ -201,6 +230,7 @@ def send_status_update(self, message: str, process_id: str):
201230 else :
202231 logger .warning ("No connection found for process ID: %s" , process_id )
203232
233+
204234class TeamConfig :
205235 """Team configuration for agents."""
206236
@@ -218,6 +248,7 @@ def get_current_team(self, user_id: str) -> TeamConfiguration:
218248 """Get the current team configuration."""
219249 return self .teams .get (user_id , None )
220250
251+
221252# Global config instances
222253azure_config = AzureConfig ()
223254mcp_config = MCPConfig ()
0 commit comments