55"""
66
77from __future__ import annotations
8+ from datetime import timedelta , timezone
9+ import datetime
10+ from logging import Logger
811from typing import TYPE_CHECKING
912from dataclasses import dataclass
1013import uuid
14+ import asyncio
1115
1216if TYPE_CHECKING :
1317 from tornado .websocket import WebSocketHandler
14-
15- @dataclass
1618class YjsClient :
1719 """Data model that represents all data associated
1820 with a user connecting to a YDoc through JupyterLab."""
1921
20- websocket : WebSocketHandler | None = None
21- """
22- The Tornado WebSocketHandler handling the WS connection to this
23- client.
24-
25- TODO: make this required
26- """
27-
28- id : str = str (uuid .uuid4 ())
22+ websocket : WebSocketHandler | None
23+ """The Tornado WebSocketHandler handling the WS connection to this client."""
24+ id : str
2925 """UUIDv4 string that uniquely identifies this client."""
30-
31- synced : bool = False
26+ synced : bool
3227 """Indicates whether the SS1 + SS2 handshake has been completed."""
28+ last_modified : datetime
29+ """Indicates the last modified time when synced state is modified"""
30+
31+ def __init__ (self , websocket ):
32+ self .websocket : WebSocketHandler | None = websocket
33+ self .id : str = str (uuid .uuid4 ())
34+ self .synced : bool = False
35+ self .last_modified = datetime .now (timezone .utc )
36+
37+
38+ @synced .setter
39+ def synced (self , v : bool ):
40+ self .synced = v
41+ self .last_modified = datetime .now (timezone .utc )
3342
3443class YjsClientGroup :
3544 """
@@ -39,42 +48,101 @@ class YjsClientGroup:
3948 New clients start as desynced. Consumers should call mark_synced() to mark a
4049 new client as synced once the SS1 + SS2 handshake is complete.
4150
42- TODO: Automatically removes desynced clients if they do not become synced after
51+ Automatically removes desynced clients if they do not become synced after
4352 a certain timeout.
4453 """
54+ room_id : str
55+ """Room Id for associated YRoom"""
4556 synced : dict [str , YjsClient ]
57+ """A dict of client_id and synced YjsClient mapping"""
4658 desynced : dict [str , YjsClient ]
59+ """A dict of client_id and desynced YjsClient mapping"""
60+ log : Logger
61+ """Log object"""
62+ loop : asyncio .AbstractEventLoop
63+ """Event loop"""
64+ poll_interval_seconds : int
65+ """The poll time interval used while auto removing desynced clients"""
66+ desynced_timeout_seconds : int
67+ """The max time period in seconds that a desynced client does not become synced before get auto removed from desynced dict"""
68+
69+ def __init__ (self , * , room_id : str , log : Logger , loop : asyncio .AbstractEventLoop , poll_interval_seconds : int = 60 , desynced_timeout_seconds : int = 120 ):
70+ self .room_id = room_id
71+ self .synced : dict [str , YjsClient ] = {}
72+ self .desynced : dict [str , YjsClient ] = {}
73+ self .log = log
74+ self .loop = loop
75+ self .loop .create_task (self ._clean_desynced ())
76+ self .poll_interval_seconds = poll_interval_seconds
77+ self .desynced_timeout_seconds = desynced_timeout_seconds
4778
4879 def add (self , websocket : WebSocketHandler ) -> str :
4980 """Adds a pending client to the group. Returns a client ID."""
50- return ""
81+ client = YjsClient (websocket )
82+ self .desynced [client .id ] = client
83+ return client .id
5184
5285 def mark_synced (self , client_id : str ) -> None :
5386 """Marks a client as synced."""
54- return
87+ if client := self .desynced .pop (client_id , None ):
88+ client .synced = True
89+ self .synced [client .id ] = client
5590
5691 def mark_desynced (self , client_id : str ) -> None :
5792 """Marks a client as desynced."""
58- return
93+ if client := self .synced .pop (client_id , None ):
94+ client .synced = False
95+ self .desynced [client .id ] = client
5996
6097 def remove (self , client_id : str ) -> None :
6198 """Removes a client from the group."""
62- return
99+ if client := self .desynced .pop (client_id , None ) is None :
100+ client = self .synced .pop (client_id , None )
101+ if client and client .websocket and client .websocket .ws_connection :
102+ try :
103+ client .websocket .close ()
104+ except Exception as e :
105+ self .log .exception (f"An exception occurred when remove client '{ client_id } ' for room '{ self .room_id } ': { e } " )
63106
64- def get (self , client_id : str , synced_only : bool = True ) -> YjsClient :
107+ def get (self , client_id : str ) -> YjsClient :
65108 """
66109 Gets a client from its ID.
67- Set synced_only=False to also get desynced clients.
68110 """
69- return YjsClient ()
111+ if client_id in self .desynced :
112+ client = self .desynced [client_id ]
113+ if client_id in self .synced :
114+ client = self .synced [client_id ]
115+ if client .websocket and client .websocket .ws_connection :
116+ return client
117+ error_message = f"The client_id '{ client_id } ' is not found in client group in room '{ self .room_id } '"
118+ self .log .error (error_message )
119+ raise Exception (error_message )
70120
71121 def get_all (self , synced_only : bool = True ) -> list [YjsClient ]:
72122 """
73123 Returns a list of all synced clients.
74124 Set synced_only=False to also get desynced clients.
75125 """
76- return []
126+ if synced_only :
127+ return list (client for client in self .synced .values () if client .websocket and client .websocket .ws_connection )
128+ return list (client for client in self .desynced .values () if client .websocket and client .websocket .ws_connection )
77129
78130 def is_empty (self ) -> bool :
79131 """Returns whether the client group is empty."""
80- return False
132+ return len (self .synced ) == 0 and len (self .desynced ) == 0
133+
134+ async def _clean_desynced (self ) -> None :
135+ while True :
136+ try :
137+ await asyncio .sleep (self ._poll_interval_seconds )
138+ for (client_id , client ) in list (self .desynced .items ()):
139+ if client .last_modified <= datetime .now (timezone .utc ) - timedelta (seconds = self .desynced_timeout_seconds ):
140+ self .log .warning (f"Remove client '{ client_id } ' for room '{ self .room_id } ' since client does not become synced after { self .desynced_timeout_seconds } seconds." )
141+ self .remove (client_id )
142+ for (client_id , client ) in list (self .synced .items ()):
143+ if client .websocket is None or client .websocket .ws_connection is None :
144+ self .log .warning (f"Remove client '{ client_id } ' for room '{ self .room_id } ' since client does not become synced after { self .desynced_timeout_seconds } seconds." )
145+ self .remove (client_id )
146+ except asyncio .CancelledError :
147+ break
148+
0 commit comments