1
1
from typing import Dict
2
- from fastapi import WebSocket , WebSocketDisconnect
3
2
import uuid
4
3
import asyncio
5
4
from app .schemas .messages import CollaboratorConnectMessage , CollaboratorDisconnectMessage , DisplayMessage
5
+ from app .core .errors import SessionNotFoundError , UserNotFoundError , InvalidUserIDsError
6
6
7
7
class ConnectionManager :
8
8
def __init__ (self ):
@@ -11,7 +11,7 @@ def __init__(self):
11
11
12
12
def init_session (self , user_ids : list [str ]) -> str :
13
13
if len (user_ids ) != 2 :
14
- raise ValueError ( "user_ids must contain exactly two user IDs." )
14
+ raise InvalidUserIDsError ( )
15
15
session_id = self .generate_uuid (user_ids [0 ], user_ids [1 ])
16
16
17
17
self .active_connections [session_id ] = {}
@@ -28,11 +28,11 @@ def _get_collaborator_q(self, session_id, user_id):
28
28
29
29
async def on_connect (self , session_id : str , user_id : str , out_q : asyncio .Queue ):
30
30
if session_id not in self .active_connections :
31
- raise ValueError ( "Session not initialized" )
31
+ raise SessionNotFoundError ( )
32
32
33
33
if user_id not in self .active_connections [session_id ]:
34
- raise ValueError ( "user_id not available in this session" )
35
-
34
+ raise UserNotFoundError ( )
35
+
36
36
self .active_connections [session_id ][user_id ] = out_q
37
37
38
38
collaborator_q = self ._get_collaborator_q (session_id , user_id )
@@ -48,11 +48,11 @@ async def on_connect(self, session_id: str, user_id: str, out_q: asyncio.Queue):
48
48
49
49
async def on_disconnect (self , session_id : str , user_id : str ):
50
50
if session_id not in self .active_connections :
51
- raise ValueError ( "Session not initialized" )
51
+ raise SessionNotFoundError ( )
52
52
53
53
if user_id not in self .active_connections [session_id ]:
54
- raise ValueError ( "user_id not available in this session" )
55
-
54
+ raise UserNotFoundError ( )
55
+
56
56
self .active_connections [session_id ][user_id ] = None
57
57
58
58
collaborator_q = self ._get_collaborator_q (session_id , user_id )
@@ -71,70 +71,9 @@ async def on_message(self, session_id: str, user_id: str, msg: str):
71
71
await collaborator_q .put (DisplayMessage (msg = msg ))
72
72
73
73
def generate_uuid (self , user1 : str , user2 : str ) -> str :
74
- namespace = uuid .NAMESPACE_DNS # or uuid.NAMESPACE_URL, etc.
74
+ namespace = uuid .NAMESPACE_DNS
75
75
seed = '' .join (sorted ([user1 , user2 ]))
76
76
return str (uuid .uuid5 (namespace , seed ))
77
77
78
- # async def connect(self, session_id: str, user_id: str, ws: WebSocket):
79
- # if session_id not in self.active_connections:
80
- # raise ValueError("Session not initialized.")
81
-
82
- # if user_id not in self.active_connections[session_id]:
83
- # raise ValueError("user_id not recognized in this session.")
84
-
85
- # await ws.accept()
86
- # self.active_connections[session_id][user_id] = ws
87
-
88
- # # Notify any waiting coroutines that a user has connected
89
- # async with self.condition:
90
- # self.condition.notify_all()
91
-
92
- # return None
93
-
94
- # async def await_collaborator_ws(self, session_id: str, user_id: str) -> WebSocket | None:
95
- # if session_id not in self.active_connections:
96
- # raise ValueError("Session not initialized.")
97
-
98
- # session_data = self.active_connections[session_id]
99
- # collaborator_id = next((uid for uid in session_data if uid != user_id), None)
100
-
101
- # user_ws = session_data[user_id]
102
- # collaborator_ws = session_data[collaborator_id]
103
-
104
- # try:
105
- # if collaborator_ws is None:
106
- # async with self.condition:
107
- # await user_ws.send_text(f"Waiting for collaborator {collaborator_id} to connect...")
108
- # await self.condition.wait_for(lambda: session_data[collaborator_id] is not None)
109
- # await user_ws.send_text(f"Collaborator {collaborator_id} connected!")
110
- # collaborator_ws = session_data[collaborator_id]
111
- # except Exception as e:
112
- # print(f"Error while waiting for collaborator: {e}")
113
- # return None
114
-
115
- # return session_data[collaborator_id]
116
-
117
- # async def disconnect(self, session_id: str, user_id: str):
118
- # if session_id not in self.active_connections:
119
- # return
120
-
121
- # session_data = self.active_connections[session_id]
122
-
123
- # ws = session_data[user_id]
124
- # session_data[user_id] = None
125
-
126
- # # Notify any waiting coroutines that a user has disconnected
127
- # async with self.condition:
128
- # self.condition.notify_all()
129
-
130
- # if ws:
131
- # try:
132
- # await ws.close()
133
- # except Exception:
134
- # pass
135
-
136
- # def get_active_connections(self, session_id: str):
137
- # return self.active_connections.get(session_id, {})
138
-
139
78
140
79
connection_manager = ConnectionManager ()
0 commit comments