Skip to content

Commit 67733d3

Browse files
committed
Websocket context solution - draft
1 parent ba12990 commit 67733d3

File tree

9 files changed

+268
-149
lines changed

9 files changed

+268
-149
lines changed

src/backend/app_kernel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,7 @@ async def get_plans(
667667
"UserIdNotFound", {"status_code": 400, "detail": "no user"}
668668
)
669669
raise HTTPException(status_code=400, detail="no user")
670-
671-
await connection_config.send_status_update_async("Test message from get_plans", user_id)
670+
672671

673672
#### <To do: Francia> Replace the following with code to get plan run history from the database
674673

src/backend/common/models/messages_kernel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ class StartingTask(KernelBaseModel):
178178
creator: str
179179
logo: str
180180

181+
class TeamSelectionRequest(KernelBaseModel):
182+
"""Request model for team selection."""
183+
team_id: str
184+
session_id: Optional[str] = None
181185

182186
class TeamConfiguration(BaseDataModel):
183187
"""Represents a team configuration stored in the database."""

src/backend/v3/api/router.py

Lines changed: 41 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import json
34
import logging
45
import uuid
@@ -9,28 +10,21 @@
910
from common.config.app_config import config
1011
from common.database.database_factory import DatabaseFactory
1112
from common.models.messages_kernel import (GeneratePlanRequest, InputTask,
12-
Plan, PlanStatus)
13+
TeamSelectionRequest)
1314
from common.utils.event_utils import track_event_if_configured
1415
from common.utils.utils_kernel import rai_success, rai_validate_team_config
1516
from fastapi import (APIRouter, BackgroundTasks, Depends, FastAPI, File,
1617
HTTPException, Request, UploadFile, WebSocket,
1718
WebSocketDisconnect)
1819
from kernel_agents.agent_factory import AgentFactory
19-
from pydantic import BaseModel
2020
from semantic_kernel.agents.runtime import InProcessRuntime
2121
from v3.common.services.team_service import TeamService
22-
from v3.config.settings import connection_config, team_config
22+
from v3.config.settings import connection_config, current_user_id, team_config
2323
from v3.orchestration.orchestration_manager import OrchestrationManager
2424

2525
router = APIRouter()
2626
logger = logging.getLogger(__name__)
2727

28-
class TeamSelectionRequest(BaseModel):
29-
"""Request model for team selection."""
30-
team_id: str
31-
session_id: Optional[str] = None
32-
33-
3428
app_v3 = APIRouter(
3529
prefix="/api/v3",
3630
responses={404: {"description": "Not found"}},
@@ -50,32 +44,38 @@ async def start_comms(websocket: WebSocket, process_id: str):
5044
authenticated_user = get_authenticated_user_details(request_headers=headers)
5145
user_id = authenticated_user.get("user_principal_id")
5246
if not user_id:
53-
user_id = f"anonymous_{process_id}"
47+
user_id = "00000000-0000-0000-0000-000000000000"
5448
except Exception as e:
5549
logging.warning(f"Could not extract user from WebSocket headers: {e}")
56-
user_id = f"anonymous_{user_id}"
50+
user_id = "00000000-0000-0000-0000-000000000000"
5751

58-
# Add to the connection manager for backend updates
52+
current_user_id.set(user_id)
5953

60-
connection_config.add_connection(user_id, websocket)
61-
track_event_if_configured("WebSocketConnectionAccepted", {"process_id": "user_id"})
54+
# Add to the connection manager for backend updates
55+
connection_config.add_connection(process_id=process_id, connection=websocket, user_id=user_id)
56+
track_event_if_configured("WebSocketConnectionAccepted", {"process_id": process_id, "user_id": user_id})
6257

6358
# Keep the connection open - FastAPI will close the connection if this returns
64-
while True:
65-
# no expectation that we will receive anything from the client but this keeps
66-
# the connection open and does not take cpu cycle
67-
try:
68-
await websocket.receive_text()
69-
except asyncio.TimeoutError:
70-
pass
71-
72-
except WebSocketDisconnect:
73-
track_event_if_configured("WebSocketDisconnect", {"process_id": process_id})
74-
logging.info(f"Client disconnected from batch {process_id}")
75-
await connection_config.close_connection(user_id)
76-
except Exception as e:
77-
logging.error("Error in WebSocket connection", error=str(e))
78-
await connection_config.close_connection(user_id)
59+
try:
60+
# Keep the connection open - FastAPI will close the connection if this returns
61+
while True:
62+
# no expectation that we will receive anything from the client but this keeps
63+
# the connection open and does not take cpu cycle
64+
try:
65+
message = await websocket.receive_text()
66+
logging.debug(f"Received WebSocket message from {user_id}: {message}")
67+
except asyncio.TimeoutError:
68+
pass
69+
except WebSocketDisconnect:
70+
track_event_if_configured("WebSocketDisconnect", {"process_id": process_id, "user_id": user_id})
71+
logging.info(f"Client disconnected from batch {process_id}")
72+
break
73+
except Exception as e:
74+
# Fixed logging syntax - removed the error= parameter
75+
logging.error(f"Error in WebSocket connection: {str(e)}")
76+
finally:
77+
# Always clean up the connection
78+
await connection_config.close_connection(user_id)
7979

8080
@app_v3.get("/init_team")
8181
async def init_team(
@@ -112,7 +112,7 @@ async def init_team(
112112
)
113113

114114
# Set as current team in memory
115-
team_config.set_current_team(user_id=user_id, team_config=team_configuration)
115+
team_config.set_current_team(user_id=user_id, team_configuration=team_configuration)
116116

117117
# Initialize agent team for this user session
118118
await OrchestrationManager.get_current_or_new_orchestration(user_id=user_id, team_config=team_configuration)
@@ -227,8 +227,16 @@ async def process_request(background_tasks: BackgroundTasks, input_task: InputTa
227227
input_task.session_id = str(uuid.uuid4())
228228

229229
try:
230-
background_tasks.add_task(OrchestrationManager.run_orchestration, user_id, input_task)
231-
#await connection_config.send_status_update_async("Test message from process_request", user_id)
230+
current_user_id.set(user_id) # Set context
231+
current_context = contextvars.copy_context() # Capture context
232+
# background_tasks.add_task(
233+
# lambda: current_context.run(lambda:OrchestrationManager().run_orchestration, user_id, input_task)
234+
# )
235+
236+
async def run_with_context():
237+
return await current_context.run(OrchestrationManager().run_orchestration, user_id, input_task)
238+
239+
background_tasks.add_task(run_with_context)
232240

233241
return {
234242
"status": "Request started successfully",
@@ -788,46 +796,4 @@ async def get_search_indexes_endpoint(request: Request):
788796
return {"search_summary": summary}
789797
except Exception as e:
790798
logging.error(f"Error retrieving search indexes: {str(e)}")
791-
raise HTTPException(status_code=500, detail="Internal server error occurred")
792-
793-
794-
# @app_v3.websocket("/socket/{process_id}")
795-
# async def process_outputs(websocket: WebSocket, process_id: str):
796-
# """ Web-Socket endpoint for real-time process status updates. """
797-
798-
# # Always accept the WebSocket connection first
799-
# await websocket.accept()
800-
801-
# user_id = None
802-
# try:
803-
# # WebSocket headers are different, try to get user info
804-
# headers = dict(websocket.headers)
805-
# authenticated_user = get_authenticated_user_details(request_headers=headers)
806-
# user_id = authenticated_user.get("user_principal_id")
807-
# if not user_id:
808-
# user_id = f"anonymous_{process_id}"
809-
# except Exception as e:
810-
# logger.warning(f"Could not extract user from WebSocket headers: {e}")
811-
# # user_id = f"anonymous_{user_id}"
812-
813-
# # Add to the connection manager for backend updates
814-
815-
# connection_config.add_connection(user_id, websocket)
816-
# track_event_if_configured("WebSocketConnectionAccepted", {"process_id": user_id})
817-
818-
# # Keep the connection open - FastAPI will close the connection if this returns
819-
# while True:
820-
# # no expectation that we will receive anything from the client but this keeps
821-
# # the connection open and does not take cpu cycle
822-
# try:
823-
# await websocket.receive_text()
824-
# except asyncio.TimeoutError:
825-
# pass
826-
827-
# except WebSocketDisconnect:
828-
# track_event_if_configured("WebSocketDisconnect", {"process_id": user_id})
829-
# logger.info(f"Client disconnected from batch {user_id}")
830-
# await connection_config.close_connection(user_id)
831-
# except Exception as e:
832-
# logger.error("Error in WebSocket connection", error=str(e))
833-
# await connection_config.close_connection(user_id)
799+
raise HTTPException(status_code=500, detail="Internal server error occurred")

src/backend/v3/callbacks/response_handlers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,23 @@
22
Enhanced response callbacks for employee onboarding agent system.
33
Provides detailed monitoring and response handling for different agent types.
44
"""
5-
5+
import asyncio
66
import sys
77

88
from semantic_kernel.contents import (ChatMessageContent,
99
StreamingChatMessageContent)
10+
from v3.config.settings import connection_config, current_user_id
1011

1112
coderagent = False
1213

14+
# Module-level variable to store current user_id
15+
_current_user_id: str = None
16+
17+
def set_user_context(user_id: str):
18+
"""Set the user context for callbacks in this module."""
19+
global _current_user_id
20+
_current_user_id = user_id
21+
1322
def agent_response_callback(message: ChatMessageContent) -> None:
1423
"""Observer function to print detailed information about streaming messages."""
1524
global coderagent
@@ -31,8 +40,19 @@ def agent_response_callback(message: ChatMessageContent) -> None:
3140
return
3241
elif coderagent == True:
3342
coderagent = False
43+
3444
role = getattr(message, 'role', 'unknown')
3545

46+
# Send to WebSocket
47+
if _current_user_id:
48+
try:
49+
asyncio.create_task(connection_config.send_status_update_async({
50+
"type": "agent_message",
51+
"data": {"agent_name": agent_name, "content": message.content, "role": role}
52+
}, _current_user_id))
53+
except Exception as e:
54+
print(f"Error sending WebSocket message: {e}")
55+
3656
print(f"\n🧠 **{agent_name}** [{message_type}] (role: {role})")
3757
print("-" * (len(agent_name) + len(message_type) + 10))
3858
if message.items[-1].content_type == 'function_call':

0 commit comments

Comments
 (0)