Skip to content

Commit ff67c1f

Browse files
jzhang20133 and Jialin Zhang authored
adding yjsclient and yjsclientgroup and server awareness observer
* Add YjsClient, YjsClientGroup, and a server awareness observer
* Add websocket check
---------
Co-authored-by: Jialin Zhang <[email protected]>
1 parent 27f8a62 commit ff67c1f

File tree

2 files changed

+130
-38
lines changed

2 files changed

+130
-38
lines changed

jupyter_rtc_core/rooms/yroom.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,43 @@
1212
from typing import Literal, Tuple
1313

1414
class YRoom:
15+
"""A Room to manage all client connection to one notebook file"""
16+
room_id: str
17+
"""Room Id"""
1518
ydoc: pycrdt.Doc
19+
"""Ydoc"""
1620
awareness: pycrdt.Awareness
21+
"""Ydoc awareness object"""
1722
loop: asyncio.AbstractEventLoop
23+
"""Event loop"""
1824
log: Logger
25+
"""Log object"""
1926
_client_group: YjsClientGroup
27+
"""Client group to manage synced and desynced clients"""
2028
_message_queue: asyncio.Queue[Tuple[str, bytes]]
29+
"""A message queue per room to keep websocket messages in order"""
2130

2231

23-
def __init__(self, log: Logger, loop: asyncio.AbstractEventLoop):
32+
def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop):
    """
    Build the room's document, awareness, client group and message queue,
    then start the background task that consumes the message queue.

    Arguments:
        room_id: Identifier of this room.
        log: Logger used by the room and handed to its client group.
        loop: Event loop on which the room schedules its background task.
    """
    # Store the constructor arguments on the instance.
    self.room_id = room_id
    self.log = log
    self.loop = loop

    # Create the YDoc and its awareness, and register the awareness callback.
    ydoc = pycrdt.Doc()
    self.ydoc = ydoc
    self.awareness = pycrdt.Awareness(ydoc=ydoc)
    self.awareness.observe(self.send_server_awareness)

    # Create the client group and the per-room message queue.
    self._client_group = YjsClientGroup(
        room_id=room_id,
        log=self.log,
        loop=self.loop,
    )
    self._message_queue = asyncio.Queue()

    # Observe the `ydoc` so that every document update is handed to
    # `write_sync_update()`.
    def _forward_update(event):
        return self.write_sync_update(event.update)

    self.ydoc.observe(_forward_update)

    # Start the background task that routes queued messages to the
    # appropriate handler method.
    self.loop.create_task(self._on_new_message())
37-
38-
# Start observer on the `ydoc` to ensure new updates are broadcast to
39-
# all clients and saved to disk.
40-
self.ydoc.observe(lambda event: self.write_sync_update(event.update))
4152

4253

4354
@property
@@ -124,7 +135,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
124135
- Sending a new SyncStep1 message immediately after.
125136
"""
126137
# Mark client as desynced
127-
new_client = self.clients.get(client_id, synced_only=False)
138+
new_client = self.clients.get(client_id)
128139
self.clients.mark_desynced(client_id)
129140

130141
# Compute SyncStep2 reply
@@ -146,7 +157,6 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
146157
# TODO: remove the assert once websocket is made required
147158
assert isinstance(new_client.websocket, WebSocketHandler)
148159
new_client.websocket.write_message(sync_step2_message)
149-
self.clients.mark_synced(client_id)
150160
except Exception as e:
151161
self.log.error(
152162
"An exception occurred when writing the SyncStep2 reply "
@@ -155,6 +165,8 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
155165
self.log.exception(e)
156166
return
157167

168+
self.clients.mark_synced(client_id)
169+
158170
# Send SyncStep1 message
159171
try:
160172
assert isinstance(new_client.websocket, WebSocketHandler)
@@ -197,9 +209,10 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None:
197209
clients after this method is called via the `self.write_sync_update()`
198210
observer.
199211
"""
200-
# Ignore the message if client is desynced
212+
# Remove client and kill websocket if received SyncUpdate when client is desynced
201213
if self._should_ignore_update(client_id, "SyncUpdate"):
202-
return
214+
self.log.error(f"Should not receive SyncUpdate message when double handshake is not completed for client '{client_id}' and room '{self.room_id}'")
215+
self._client_group.remove(client_id)
203216

204217
# Apply the SyncUpdate to the YDoc
205218
try:
@@ -235,10 +248,6 @@ def write_sync_update(self, message_payload: bytes, client_id: str | None = None
235248

236249

237250
def handle_awareness_update(self, client_id: str, message: bytes) -> None:
238-
# Ignore the message if client is desynced
239-
if self._should_ignore_update(client_id, "AwarenessUpdate"):
240-
return
241-
242251
# Apply the AwarenessUpdate message
243252
try:
244253
message_payload = message[1:]
@@ -263,7 +272,7 @@ def _should_ignore_update(self, client_id: str, message_type: Literal['Awareness
263272
more readable warnings.
264273
"""
265274

266-
client = self.clients.get(client_id, synced_only=False)
275+
client = self.clients.get(client_id)
267276
if not client.synced:
268277
self.log.warning(
269278
f"Ignoring a {message_type} message from client "
@@ -293,7 +302,22 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd
293302
f"to client '{client.id}:'"
294303
)
295304
self.log.exception(e)
305+
306+
def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any]) -> None:
    """
    Awareness observer that broadcasts locally-originated awareness
    updates to clients.

    Arguments:
        type: The change type. Anything other than "update" is ignored.
        changes: The awareness changes. `changes[1]` must equal "local"
            for the update to be broadcast; the values of `changes[0]`
            are flattened and passed to `encode_awareness_update()`
            (presumably client ids; confirm against pycrdt).
    """
    should_broadcast = type == "update" and changes[1] == "local"
    if not should_broadcast:
        return

    # Flatten every collection in the change dict into a single list.
    updated_clients = []
    for group in changes[0].values():
        updated_clients.extend(group)

    encoded_state = self.awareness.encode_awareness_update(updated_clients)
    awareness_message = pycrdt.create_awareness_message(encoded_state)
    self._broadcast_message(awareness_message, "AwarenessUpdate")
297321

298322
def stop(self) -> None:
299323
# TODO: requires YRoomLoader implementation

jupyter_rtc_core/websockets/clients.py

Lines changed: 91 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,40 @@
55
"""
66

77
from __future__ import annotations
8+
from datetime import timedelta, timezone
9+
import datetime
10+
from logging import Logger
811
from typing import TYPE_CHECKING
912
from dataclasses import dataclass
1013
import uuid
14+
import asyncio
1115

1216
if TYPE_CHECKING:
1317
from tornado.websocket import WebSocketHandler
14-
15-
@dataclass
1618
class YjsClient:
    """Data model that represents all data associated
    with a user connecting to a YDoc through JupyterLab.

    `synced` is a property: assigning to it also refreshes
    `last_modified`, so the time of the last sync-state change is
    always recorded.
    """

    websocket: WebSocketHandler | None
    """The Tornado WebSocketHandler handling the WS connection to this client."""
    id: str
    """UUIDv4 string that uniquely identifies this client."""
    last_modified: datetime.datetime
    """Timezone-aware (UTC) time at which `synced` was last assigned."""
    _synced: bool
    """Backing field for the `synced` property."""

    def __init__(self, websocket: WebSocketHandler | None):
        """
        Create a new, desynced client with a fresh UUIDv4 id.

        Arguments:
            websocket: The WebSocket handler serving this client, or None.
        """
        self.websocket = websocket
        self.id = str(uuid.uuid4())
        self._synced = False
        # `datetime` is the module here (the file does `import datetime`),
        # so the class must be reached as `datetime.datetime`.
        self.last_modified = datetime.datetime.now(timezone.utc)

    @property
    def synced(self) -> bool:
        """Indicates whether the SS1 + SS2 handshake has been completed."""
        return self._synced

    @synced.setter
    def synced(self, v: bool) -> None:
        # Write the backing field; assigning `self.synced` here would
        # re-enter this setter and recurse forever.
        self._synced = v
        self.last_modified = datetime.datetime.now(timezone.utc)
3342

3443
class YjsClientGroup:
3544
"""
@@ -39,42 +48,101 @@ class YjsClientGroup:
3948
New clients start as desynced. Consumers should call mark_synced() to mark a
4049
new client as synced once the SS1 + SS2 handshake is complete.
4150
42-
TODO: Automatically removes desynced clients if they do not become synced after
51+
Automatically removes desynced clients if they do not become synced after
4352
a certain timeout.
4453
"""
54+
room_id: str
55+
"""Room Id for associated YRoom"""
4556
synced: dict[str, YjsClient]
57+
"""A dict of client_id and synced YjsClient mapping"""
4658
desynced: dict[str, YjsClient]
59+
"""A dict of client_id and desynced YjsClient mapping"""
60+
log: Logger
61+
"""Log object"""
62+
loop: asyncio.AbstractEventLoop
63+
"""Event loop"""
64+
poll_interval_seconds: int
65+
"""The poll time interval used while auto removing desynced clients"""
66+
desynced_timeout_seconds: int
67+
"""The max time period in seconds that a desynced client does not become synced before get auto removed from desynced dict"""
68+
69+
def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop, poll_interval_seconds: int = 60, desynced_timeout_seconds: int = 120):
70+
self.room_id = room_id
71+
self.synced: dict[str, YjsClient] = {}
72+
self.desynced: dict[str, YjsClient] = {}
73+
self.log = log
74+
self.loop = loop
75+
self.loop.create_task(self._clean_desynced())
76+
self.poll_interval_seconds = poll_interval_seconds
77+
self.desynced_timeout_seconds = desynced_timeout_seconds
4778

4879
def add(self, websocket: WebSocketHandler) -> str:
    """
    Register a new client for `websocket` in the desynced dict.

    Returns the ID of the newly created client.
    """
    new_client = YjsClient(websocket)
    new_id = new_client.id
    self.desynced[new_id] = new_client
    return new_id
5184

5285
def mark_synced(self, client_id: str) -> None:
5386
"""Marks a client as synced."""
54-
return
87+
if client := self.desynced.pop(client_id, None):
88+
client.synced = True
89+
self.synced[client.id] = client
5590

5691
def mark_desynced(self, client_id: str) -> None:
5792
"""Marks a client as desynced."""
58-
return
93+
if client := self.synced.pop(client_id, None):
94+
client.synced = False
95+
self.desynced[client.id] = client
5996

6097
def remove(self, client_id: str) -> None:
6198
"""Removes a client from the group."""
62-
return
99+
if client := self.desynced.pop(client_id, None) is None:
100+
client = self.synced.pop(client_id, None)
101+
if client and client.websocket and client.websocket.ws_connection:
102+
try:
103+
client.websocket.close()
104+
except Exception as e:
105+
self.log.exception(f"An exception occurred when remove client '{client_id}' for room '{self.room_id}': {e}")
63106

64-
def get(self, client_id: str, synced_only: bool = True) -> YjsClient:
107+
def get(self, client_id: str) -> YjsClient:
65108
"""
66109
Gets a client from its ID.
67-
Set synced_only=False to also get desynced clients.
68110
"""
69-
return YjsClient()
111+
if client_id in self.desynced:
112+
client = self.desynced[client_id]
113+
if client_id in self.synced:
114+
client = self.synced[client_id]
115+
if client.websocket and client.websocket.ws_connection:
116+
return client
117+
error_message = f"The client_id '{client_id}' is not found in client group in room '{self.room_id}'"
118+
self.log.error(error_message)
119+
raise Exception(error_message)
70120

71121
def get_all(self, synced_only: bool = True) -> list[YjsClient]:
72122
"""
73123
Returns a list of all synced clients.
74124
Set synced_only=False to also get desynced clients.
75125
"""
76-
return []
126+
if synced_only:
127+
return list(client for client in self.synced.values() if client.websocket and client.websocket.ws_connection)
128+
return list(client for client in self.desynced.values() if client.websocket and client.websocket.ws_connection)
77129

78130
def is_empty(self) -> bool:
79131
"""Returns whether the client group is empty."""
80-
return False
132+
return len(self.synced) == 0 and len(self.desynced) == 0
133+
134+
async def _clean_desynced(self) -> None:
135+
while True:
136+
try:
137+
await asyncio.sleep(self._poll_interval_seconds)
138+
for (client_id, client) in list(self.desynced.items()):
139+
if client.last_modified <= datetime.now(timezone.utc) - timedelta(seconds=self.desynced_timeout_seconds):
140+
self.log.warning(f"Remove client '{client_id}' for room '{self.room_id}' since client does not become synced after {self.desynced_timeout_seconds} seconds.")
141+
self.remove(client_id)
142+
for (client_id, client) in list(self.synced.items()):
143+
if client.websocket is None or client.websocket.ws_connection is None:
144+
self.log.warning(f"Remove client '{client_id}' for room '{self.room_id}' since client does not become synced after {self.desynced_timeout_seconds} seconds.")
145+
self.remove(client_id)
146+
except asyncio.CancelledError:
147+
break
148+

0 commit comments

Comments
 (0)