Skip to content

Commit c251672

Browse files
committed
Add WebSocket message type support and refactor messaging
Introduces the WebsocketMessageType enum for consistent message typing across WebSocket communications. Refactors response handlers, connection management, and orchestration modules to utilize message_type for improved clarity and extensibility. Also removes an unused __init__.py file from common services.
1 parent dd3d046 commit c251672

File tree

6 files changed

+68
-33
lines changed

6 files changed

+68
-33
lines changed

src/backend/common/services/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/backend/v3/callbacks/response_handlers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
StreamingChatMessageContent)
1313
from v3.config.settings import connection_config, current_user_id
1414
from v3.models.messages import (AgentMessage, AgentMessageStreaming,
15-
AgentToolCall, AgentToolMessage)
15+
AgentToolCall, AgentToolMessage, WebsocketMessageType)
1616

1717

1818
def agent_response_callback(message: ChatMessageContent, user_id: str = None) -> None:
@@ -35,14 +35,16 @@ def agent_response_callback(message: ChatMessageContent, user_id: str = None) ->
3535
if item.content_type == 'function_call':
3636
tool_call = AgentToolCall(tool_name=item.name or "unknown_tool", arguments=item.arguments or {})
3737
final_message.tool_calls.append(tool_call)
38-
asyncio.create_task(connection_config.send_status_update_async(final_message, user_id))
38+
39+
asyncio.create_task(connection_config.send_status_update_async(final_message, user_id, message_type=WebsocketMessageType.AGENT_TOOL_MESSAGE))
3940
logging.info(f"Function call: {final_message}")
4041
elif message.items and message.items[0].content_type == 'function_result':
4142
# skip returning these results for now - agent will return in a later message
4243
pass
4344
else:
4445
final_message = AgentMessage(agent_name=agent_name, timestamp=time.time() or "", content=message.content or "")
45-
asyncio.create_task(connection_config.send_status_update_async(final_message, user_id))
46+
47+
asyncio.create_task(connection_config.send_status_update_async(final_message, user_id, message_type=WebsocketMessageType.AGENT_MESSAGE))
4648
logging.info(f"{role.capitalize()} message: {final_message}")
4749
except Exception as e:
4850
logging.error(f"Response_callback: Error sending WebSocket message: {e}")

src/backend/v3/config/settings.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@
1414
from fastapi import WebSocket
1515
from semantic_kernel.agents.orchestration.magentic import MagenticOrchestration
1616
from semantic_kernel.connectors.ai.open_ai import (
17-
AzureChatCompletion, OpenAIChatPromptExecutionSettings)
17+
AzureChatCompletion,
18+
OpenAIChatPromptExecutionSettings,
19+
)
20+
21+
from v3.models.messages import WebsocketMessageType
1822

1923
logger = logging.getLogger(__name__)
2024

2125
# Create a context variable to track current user
22-
current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('current_user_id', default=None)
26+
current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
27+
"current_user_id", default=None
28+
)
29+
2330

2431
class AzureConfig:
2532
"""Azure OpenAI and authentication configuration."""
@@ -37,7 +44,7 @@ def ad_token_provider(self) -> str:
3744
token = self.credential.get_token(config.AZURE_COGNITIVE_SERVICES)
3845
return token.token
3946

40-
async def create_chat_completion_service(self, use_reasoning_model: bool=False):
47+
async def create_chat_completion_service(self, use_reasoning_model: bool = False):
4148
"""Create Azure Chat Completion service."""
4249
model_name = (
4350
self.reasoning_model if use_reasoning_model else self.standard_model
@@ -75,16 +82,19 @@ class OrchestrationConfig:
7582
"""Configuration for orchestration settings."""
7683

7784
def __init__(self):
78-
self.orchestrations: Dict[str, MagenticOrchestration] = {} # user_id -> orchestration instance
79-
self.plans: Dict[str, any] = {} # plan_id -> plan details
80-
self.approvals: Dict[str, bool] = {} # m_plan_id -> approval status
81-
self.sockets: Dict[str, WebSocket] = {} # user_id -> WebSocket
82-
self.clarifications: Dict[str, str] = {} # m_plan_id -> clarification response
85+
self.orchestrations: Dict[str, MagenticOrchestration] = (
86+
{}
87+
) # user_id -> orchestration instance
88+
self.plans: Dict[str, any] = {} # plan_id -> plan details
89+
self.approvals: Dict[str, bool] = {} # m_plan_id -> approval status
90+
self.sockets: Dict[str, WebSocket] = {} # user_id -> WebSocket
91+
self.clarifications: Dict[str, str] = {} # m_plan_id -> clarification response
8392

8493
def get_current_orchestration(self, user_id: str) -> MagenticOrchestration:
8594
"""get existing orchestration instance."""
8695
return self.orchestrations.get(user_id, None)
87-
96+
97+
8898
class ConnectionConfig:
8999
"""Connection manager for WebSocket connections."""
90100

@@ -93,15 +103,19 @@ def __init__(self):
93103
# Map user_id to process_id for context-based messaging
94104
self.user_to_process: Dict[str, str] = {}
95105

96-
def add_connection(self, process_id: str, connection: WebSocket, user_id: str = None):
106+
def add_connection(
107+
self, process_id: str, connection: WebSocket, user_id: str = None
108+
):
97109
"""Add a new connection."""
98110
# Close existing connection if it exists
99111
if process_id in self.connections:
100112
try:
101113
asyncio.create_task(self.connections[process_id].close())
102114
except Exception as e:
103-
logger.error(f"Error closing existing connection for user {process_id}: {e}")
104-
115+
logger.error(
116+
f"Error closing existing connection for user {process_id}: {e}"
117+
)
118+
105119
self.connections[process_id] = connection
106120
# Map user to process for context-based messaging
107121
if user_id:
@@ -114,12 +128,18 @@ def add_connection(self, process_id: str, connection: WebSocket, user_id: str =
114128
try:
115129
asyncio.create_task(old_connection.close())
116130
del self.connections[old_process_id]
117-
logger.info(f"Closed old connection {old_process_id} for user {user_id}")
131+
logger.info(
132+
f"Closed old connection {old_process_id} for user {user_id}"
133+
)
118134
except Exception as e:
119-
logger.error(f"Error closing old connection for user {user_id}: {e}")
120-
135+
logger.error(
136+
f"Error closing old connection for user {user_id}: {e}"
137+
)
138+
121139
self.user_to_process[user_id] = process_id
122-
logger.info(f"WebSocket connection added for process: {process_id} (user: {user_id})")
140+
logger.info(
141+
f"WebSocket connection added for process: {process_id} (user: {user_id})"
142+
)
123143
else:
124144
logger.info(f"WebSocket connection added for process: {process_id}")
125145

@@ -128,7 +148,7 @@ def remove_connection(self, process_id):
128148
process_id = str(process_id)
129149
if process_id in self.connections:
130150
del self.connections[process_id]
131-
151+
132152
# Remove from user mapping if exists
133153
for user_id, mapped_process_id in list(self.user_to_process.items()):
134154
if mapped_process_id == process_id:
@@ -139,7 +159,7 @@ def remove_connection(self, process_id):
139159
def get_connection(self, process_id):
140160
"""Get a connection."""
141161
return self.connections.get(process_id)
142-
162+
143163
async def close_connection(self, process_id):
144164
"""Remove a connection."""
145165
connection = self.get_connection(process_id)
@@ -156,22 +176,29 @@ async def close_connection(self, process_id):
156176
self.remove_connection(process_id)
157177
logger.info("Connection removed for batch ID: %s", process_id)
158178

159-
async def send_status_update_async(self, message: any, user_id: Optional[str] = None):
179+
async def send_status_update_async(
180+
self,
181+
message: any,
182+
user_id: Optional[str] = None,
183+
message_type: WebsocketMessageType = WebsocketMessageType.SYSTEM_MESSAGE,
184+
):
160185
"""Send a status update to a specific client."""
161186
# If no process_id provided, get from context
162187
if user_id is None:
163188
user_id = current_user_id.get()
164-
189+
165190
if not user_id:
166191
logger.warning("No user_id available for WebSocket message")
167192
return
168-
193+
169194
process_id = self.user_to_process.get(user_id)
170195
if not process_id:
171196
logger.warning("No active WebSocket process found for user ID: %s", user_id)
172-
logger.debug(f"Available user mappings: {list(self.user_to_process.keys())}")
197+
logger.debug(
198+
f"Available user mappings: {list(self.user_to_process.keys())}"
199+
)
173200
return
174-
201+
175202
connection = self.get_connection(process_id)
176203
if connection:
177204
try:
@@ -183,7 +210,9 @@ async def send_status_update_async(self, message: any, user_id: Optional[str] =
183210
# Clean up stale connection
184211
self.remove_connection(process_id)
185212
else:
186-
logger.warning("No connection found for process ID: %s (user: %s)", process_id, user_id)
213+
logger.warning(
214+
"No connection found for process ID: %s (user: %s)", process_id, user_id
215+
)
187216
# Clean up stale mapping
188217
if user_id in self.user_to_process:
189218
del self.user_to_process[user_id]
@@ -201,6 +230,7 @@ def send_status_update(self, message: str, process_id: str):
201230
else:
202231
logger.warning("No connection found for process ID: %s", process_id)
203232

233+
204234
class TeamConfig:
205235
"""Team configuration for agents."""
206236

@@ -218,6 +248,7 @@ def get_current_team(self, user_id: str) -> TeamConfiguration:
218248
"""Get the current team configuration."""
219249
return self.teams.get(user_id, None)
220250

251+
221252
# Global config instances
222253
azure_config = AzureConfig()
223254
mcp_config = MCPConfig()

src/backend/v3/models/messages.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ class ApprovalRequest(KernelBaseModel):
141141

142142

143143
class WebsocketMessageType(str, Enum):
144+
"""Types of WebSocket messages."""
145+
SYSTEM_MESSAGE = "system_message"
144146
AGENT_MESSAGE = "agent_message"
145147
AGENT_STREAM_START = "agent_stream_start"
146148
AGENT_STREAM_END = "agent_stream_end"

src/backend/v3/orchestration/human_approval_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
8787
# Send the approval request to the user's WebSocket
8888
# The user_id will be automatically retrieved from context
8989
await connection_config.send_status_update_async({
90-
"type": "plan_approval_request",
90+
"type": messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST,
9191
"data": approval_message
92-
})
92+
}, user_id=current_user_id.get(), message_type=messages.WebsocketMessageType.PLAN_APPROVAL_REQUEST)
9393

9494
# Wait for user approval
9595
approval_response = await self._wait_for_user_approval(approval_message.plan.id)
@@ -100,9 +100,9 @@ async def plan(self, magentic_context: MagenticContext) -> Any:
100100
else:
101101
print("Plan execution cancelled by user")
102102
await connection_config.send_status_update_async({
103-
"type": "plan_approval_response",
103+
"type": messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE,
104104
"data": approval_response
105-
})
105+
}, user_id=current_user_id.get(), message_type=messages.WebsocketMessageType.PLAN_APPROVAL_RESPONSE)
106106
raise Exception("Plan execution cancelled by user")
107107
# return ChatMessageContent(
108108
# role="assistant",

src/backend/v3/orchestration/orchestration_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
AzureChatCompletion, OpenAIChatPromptExecutionSettings)
1818
from semantic_kernel.contents import (ChatMessageContent,
1919
StreamingChatMessageContent)
20+
from v3.models.messages import WebsocketMessageType
2021
from v3.callbacks.response_handlers import (agent_response_callback,
2122
streaming_agent_response_callback)
2223
from v3.config.settings import (config, connection_config, current_user_id,
@@ -140,7 +141,7 @@ async def run_orchestration(self, user_id, input_task) -> None:
140141
"status": "completed",
141142
"timestamp": str(uuid.uuid4()) # or use actual timestamp
142143
}
143-
}, user_id)
144+
}, user_id, message_type=WebsocketMessageType.FINAL_RESULT_MESSAGE)
144145
print(f"Final result sent via WebSocket to user {user_id}")
145146
except Exception as e:
146147
print(f"Error: {e}")

0 commit comments

Comments
 (0)