1313
1414if TYPE_CHECKING :
1515 from tornado .websocket import WebSocketHandler
16+
17+
1618class YjsClient :
1719 """Data model that represents all data associated
1820 with a user connecting to a YDoc through JupyterLab."""
1921
20- websocket : WebSocketHandler | None
22+ websocket : WebSocketHandler
2123 """The Tornado WebSocketHandler handling the WS connection to this client."""
2224 id : str
2325 """UUIDv4 string that uniquely identifies this client."""
@@ -27,21 +29,23 @@ class YjsClient:
2729 _synced : bool
2830
2931 def __init__ (self , websocket ):
30- self .websocket : WebSocketHandler | None = websocket
31- self .id : str = str (uuid .uuid4 ())
32- self ._synced : bool = False
32+ self .websocket = websocket
33+ self .id = str (uuid .uuid4 ())
34+ self ._synced = False
3335 self .last_modified = datetime .now (timezone .utc )
3436
37+
3538 @property
36- def synced (self ):
39+ def synced (self ) -> bool :
3740 """
3841 Indicates whether the initial Client SS1 + Server SS2 handshake has been
3942 completed.
4043 """
4144 return self ._synced
4245
46+
4347 @synced .setter
44- def synced (self , v : bool ):
48+ def synced (self , v : bool ) -> None :
4549 self ._synced = v
4650 self .last_modified = datetime .now (timezone .utc )
4751
@@ -70,6 +74,11 @@ class YjsClientGroup:
7074 """The poll time interval used while auto removing desynced clients"""
7175 desynced_timeout_seconds : int
7276 """The max time period in seconds that a desynced client does not become synced before get auto removed from desynced dict"""
77+ _stopped : bool
78+ """
79+ Whether the client group has been stopped. If `True`, this client group will
80+ ignore all future calls to `add()`.
81+ """
7382
7483 def __init__ (self , * , room_id : str , log : Logger , loop : asyncio .AbstractEventLoop , poll_interval_seconds : int = 60 , desynced_timeout_seconds : int = 120 ):
7584 self .room_id = room_id
@@ -80,9 +89,20 @@ def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop
8089 # self.loop.create_task(self._clean_desynced())
8190 self ._poll_interval_seconds = poll_interval_seconds
8291 self .desynced_timeout_seconds = desynced_timeout_seconds
92+ self ._stopped = False
8393
84- def add (self , websocket : WebSocketHandler ) -> str :
85- """Adds a pending client to the group. Returns a client ID."""
94+ def add (self , websocket : WebSocketHandler ) -> str | None :
95+ """
96+ Adds a new Websocket as a client to the group.
97+
98+ Returns a client ID if the client group is active.
99+
100+ Returns `None` if the client group is stopped, which is set after
101+ `stop()` is called.
102+ """
103+ if self ._stopped :
104+ return
105+
86106 client = YjsClient (websocket )
87107 self .desynced [client .id ] = client
88108 return client .id
@@ -101,13 +121,16 @@ def mark_desynced(self, client_id: str) -> None:
101121
102122 def remove (self , client_id : str ) -> None :
103123 """Removes a client from the group."""
104- if client := self .desynced .pop (client_id , None ) is None :
124+ client = self .desynced .pop (client_id , None )
125+ if not client :
105126 client = self .synced .pop (client_id , None )
106- if client and client .websocket and client .websocket .ws_connection :
107- try :
108- client .websocket .close ()
109- except Exception as e :
110- self .log .exception (f"An exception occurred when remove client '{ client_id } ' for room '{ self .room_id } ': { e } " )
127+ if not client :
128+ return
129+
130+ try :
131+ client .websocket .close ()
132+ except Exception as e :
133+ self .log .exception (f"An exception occurred when remove client '{ client_id } ' for room '{ self .room_id } ': { e } " )
111134
112135 def get (self , client_id : str ) -> YjsClient :
113136 """
@@ -128,10 +151,12 @@ def get_all(self, synced_only: bool = True) -> list[YjsClient]:
128151 Returns a list of all synced clients.
129152 Set synced_only=False to also get desynced clients.
130153 """
131- if synced_only :
132- return list (client for client in self .synced .values () if client .websocket and client .websocket .ws_connection )
133- return list (client for client in self .desynced .values () if client .websocket and client .websocket .ws_connection )
134-
154+ all_clients = [c for c in self .synced .values ()]
155+ if not synced_only :
156+ all_clients += [c for c in self .desynced .values ()]
157+
158+ return all_clients
159+
135160 def is_empty (self ) -> bool :
136161 """Returns whether the client group is empty."""
137162 return len (self .synced ) == 0 and len (self .desynced ) == 0
@@ -150,4 +175,16 @@ async def _clean_desynced(self) -> None:
150175 self .remove (client_id )
151176 except asyncio .CancelledError :
152177 break
153-
178+
179+ def stop (self ):
180+ """
181+ Closes all Websocket connections with close code 1001 (server
182+ shutting down) and ignores any future calls to `add()`.
183+ """
184+ # Close all Websocket connections
185+ for client in self .get_all (synced_only = False ):
186+ client .websocket .close (code = 1001 )
187+
188+ # Set `_stopped` to `True` to ignore future calls to `add()`
189+ self ._stopped = True
190+
0 commit comments