Skip to content

Commit bb81ccd

Browse files
committed
implement stopping all YRooms
1 parent d23e9f7 commit bb81ccd

File tree

6 files changed

+169
-75
lines changed

6 files changed

+169
-75
lines changed

jupyter_rtc_core/app.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@ class RtcExtensionApp(ExtensionApp):
3131
default_value=YRoomManager,
3232
).tag(config=True)
3333

34-
yroom_manager = Instance(
35-
klass=YRoomManager,
36-
help="An instance of the YRoom Manager.",
37-
allow_none=True
38-
).tag(config=True)
39-
34+
@property
35+
def yroom_manager(self) -> YRoomManager | None:
36+
return self.settings.get("yroom_manager", None)
4037

4138
def initialize(self):
4239
super().initialize()
@@ -71,3 +68,9 @@ def _link_jupyter_server_extension(self, server_app):
7168
c.ServerApp.session_manager_class = "jupyter_rtc_core.session_manager.YDocSessionManager"
7269
server_app.update_config(c)
7370
super()._link_jupyter_server_extension(server_app)
71+
72+
async def stop_extension(self):
73+
self.log.info("Stopping `jupyter_rtc_core` server extension.")
74+
if self.yroom_manager:
75+
await self.yroom_manager.stop()
76+

jupyter_rtc_core/rooms/yroom.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@ def __init__(
8585
# Start observers on `self.ydoc` and `self.awareness` to ensure new
8686
# updates are broadcast to all clients and saved to disk.
8787
self._awareness_subscription = self._awareness.observe(
88-
self.send_server_awareness
88+
self._on_awareness_update
8989
)
9090
self._ydoc_subscription = self._ydoc.observe(
91-
lambda event: self.write_sync_update(event.update)
91+
lambda event: self._on_ydoc_update(event.update)
9292
)
9393

9494
# Initialize message queue and start background task that routes new
9595
# messages in the message queue to the appropriate handler method.
9696
self._message_queue = asyncio.Queue()
97-
self._loop.create_task(self._on_new_message())
97+
self._loop.create_task(self._process_message_queue())
9898

9999
# Log notification that room is ready
100100
self.log.info(f"Room '{self.room_id}' initialized.")
@@ -144,7 +144,7 @@ def add_message(self, client_id: str, message: bytes) -> None:
144144
self._message_queue.put_nowait((client_id, message))
145145

146146

147-
async def _on_new_message(self) -> None:
147+
async def _process_message_queue(self) -> None:
148148
"""
149149
Async method that only runs when a new message arrives in the message
150150
queue. This method routes the message to a handler method based on the
@@ -281,7 +281,7 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None:
281281
Handles incoming SyncUpdate messages by applying the update to the YDoc.
282282
283283
A SyncUpdate message will automatically be broadcast to all synced
284-
clients after this method is called via the `self.write_sync_update()`
284+
clients after this method is called via the `self._on_ydoc_update()`
285285
observer.
286286
"""
287287
# Remove client and kill websocket if received SyncUpdate when client is desynced
@@ -302,7 +302,7 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None:
302302
return
303303

304304

305-
def write_sync_update(self, message_payload: bytes, client_id: str | None = None) -> None:
305+
def _on_ydoc_update(self, message_payload: bytes, client_id: str | None = None) -> None:
306306
"""
307307
This method is an observer on `self.ydoc` which:
308308
@@ -377,9 +377,10 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd
377377
)
378378
self.log.exception(e)
379379

380-
def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any]) -> None:
380+
def _on_awareness_update(self, type: str, changes: tuple[dict[str, Any], Any]) -> None:
381381
"""
382-
Callback to broadcast the server awareness to clients.
382+
Observer on `self.awareness` that broadcasts AwarenessUpdate messages to
383+
all clients on update.
383384
384385
Arguments:
385386
type: The change type.
@@ -393,14 +394,22 @@ def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any])
393394
message = pycrdt.create_awareness_message(state)
394395
self._broadcast_message(message, "AwarenessUpdate")
395396

396-
def stop(self) -> None:
397+
async def stop(self) -> None:
397398
"""
398-
Stop the YRoom.
399-
400-
TODO: stop file API & stop the message processing loop
399+
Stops the YRoom gracefully.
401400
"""
401+
# First, disconnect all clients by stopping the client group.
402+
self.clients.stop()
403+
404+
# Remove all observers, as updates no longer need to be broadcast
402405
self._ydoc.unobserve(self._ydoc_subscription)
403406
self._awareness.unobserve(self._awareness_subscription)
404407

405-
return
408+
# Finish processing all messages, then stop the queue to end the
409+
# `_process_message_queue()` background task.
410+
await self._message_queue.join()
411+
self._message_queue.shutdown(immediate=True)
406412

413+
# Finally, stop FileAPI and return. This saves the final content of the
414+
# JupyterYDoc in the process.
415+
await self.file_api.stop()

jupyter_rtc_core/rooms/yroom_file_api.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -161,38 +161,67 @@ def schedule_save(self) -> None:
161161

162162

163163
async def _process_scheduled_saves(self) -> None:
164+
"""
165+
Defines a background task that processes scheduled saves, after waiting
166+
for the JupyterYDoc content to be loaded.
167+
"""
168+
164169
# Wait for content to be loaded before processing scheduled saves
165170
await self._ydoc_content_loaded.wait()
166171

167172
while True:
168173
try:
169174
await self._scheduled_saves.get()
170175
except asyncio.QueueShutDown:
171-
return
176+
break
177+
178+
await self._save_jupyter_ydoc()
172179

173-
try:
174-
assert self.jupyter_ydoc
175-
path = self.get_path()
176-
content = self.jupyter_ydoc.source
177-
file_format = self.file_format
178-
file_type = self.file_type if self.file_type in SAVEABLE_FILE_TYPES else "file"
179-
180-
# Save the YDoc via the ContentsManager
181-
await ensure_async(self._contents_manager.save(
182-
{
183-
"format": file_format,
184-
"type": file_type,
185-
"content": content,
186-
},
187-
path
188-
))
189-
190-
# Mark 'dirty' as `False`. This hides the "unsaved changes" icon
191-
# in the JupyterLab tab rendering this YDoc in the frontend.
192-
self.jupyter_ydoc.dirty = False
193-
except Exception as e:
194-
self.log.error("An exception occurred when saving JupyterYDoc.")
195-
self.log.exception(e)
180+
181+
async def _save_jupyter_ydoc(self):
182+
"""
183+
Saves the JupyterYDoc to disk immediately.
184+
185+
This is a private method, and should only be called through the
186+
`_process_scheduled_saves()` task and the `stop()` method. Consumers
187+
should instead call `schedule_save()` to save the document.
188+
"""
189+
try:
190+
assert self.jupyter_ydoc
191+
path = self.get_path()
192+
content = self.jupyter_ydoc.source
193+
file_format = self.file_format
194+
file_type = self.file_type if self.file_type in SAVEABLE_FILE_TYPES else "file"
195+
196+
# Save the YDoc via the ContentsManager
197+
await ensure_async(self._contents_manager.save(
198+
{
199+
"format": file_format,
200+
"type": file_type,
201+
"content": content,
202+
},
203+
path
204+
))
205+
206+
# Mark 'dirty' as `False`. This hides the "unsaved changes" icon
207+
# in the JupyterLab tab rendering this YDoc in the frontend.
208+
self.jupyter_ydoc.dirty = False
209+
except Exception as e:
210+
self.log.error("An exception occurred when saving JupyterYDoc.")
211+
self.log.exception(e)
212+
213+
async def stop(self) -> None:
214+
"""
215+
Gracefully stops the YRoomFileAPI, saving the content of
216+
`self.jupyter_ydoc` before exiting.
217+
"""
218+
# Do nothing if content was not loaded first. This prevents overwriting
219+
# the existing file with an empty JupyterYDoc.
220+
if not (self._ydoc_content_loaded.is_set()):
221+
return
222+
223+
# Save the file and return.
224+
await self._save_jupyter_ydoc()
196225

197226

198227
# see https://github.com/jupyterlab/jupyter-collaboration/blob/main/projects/jupyter-server-ydoc/jupyter_server_ydoc/loaders.py#L146-L149

jupyter_rtc_core/rooms/yroom_manager.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from __future__ import annotations
2-
from typing import Any, Dict, Optional
3-
from traitlets import HasTraits, Instance, default
42

53
from .yroom import YRoom
64
from typing import TYPE_CHECKING
@@ -27,8 +25,9 @@ def __init__(
2725
self.contents_manager = contents_manager
2826
self.loop = loop
2927
self.log = log
30-
self._rooms_by_id = {}
28+
3129
# Initialize dictionary of YRooms, keyed by room ID
30+
self._rooms_by_id = {}
3231

3332

3433
@property
@@ -66,16 +65,23 @@ def get_room(self, room_id: str) -> YRoom | None:
6665
return None
6766

6867

69-
def delete_room(self, room_id: str) -> None:
68+
async def delete_room(self, room_id: str) -> None:
7069
"""
71-
Deletes a YRoom given a room ID.
72-
73-
TODO: finish implementing YRoom.stop(), and delete empty rooms w/ no
74-
live kernels automatically in a background task.
70+
Gracefully deletes a YRoom given a room ID. This stops the YRoom first,
71+
which finishes applying all updates & saves the content automatically.
7572
"""
7673
yroom = self._rooms_by_id.get(room_id, None)
7774
if not yroom:
7875
return
7976

80-
yroom.stop()
77+
await yroom.stop()
8178
del self._rooms_by_id[room_id]
79+
80+
81+
async def stop(self) -> None:
82+
"""
83+
Gracefully deletes each `YRoom`. See `delete_room()` for more info.
84+
"""
85+
for room_id in self._rooms_by_id.keys():
86+
await self.delete_room(room_id)
87+

jupyter_rtc_core/websockets/clients.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313

1414
if TYPE_CHECKING:
1515
from tornado.websocket import WebSocketHandler
16+
17+
1618
class YjsClient:
1719
"""Data model that represents all data associated
1820
with a user connecting to a YDoc through JupyterLab."""
1921

20-
websocket: WebSocketHandler | None
22+
websocket: WebSocketHandler
2123
"""The Tornado WebSocketHandler handling the WS connection to this client."""
2224
id: str
2325
"""UUIDv4 string that uniquely identifies this client."""
@@ -27,21 +29,23 @@ class YjsClient:
2729
_synced: bool
2830

2931
def __init__(self, websocket):
30-
self.websocket: WebSocketHandler | None = websocket
31-
self.id: str = str(uuid.uuid4())
32-
self._synced: bool = False
32+
self.websocket = websocket
33+
self.id = str(uuid.uuid4())
34+
self._synced = False
3335
self.last_modified = datetime.now(timezone.utc)
3436

37+
3538
@property
36-
def synced(self):
39+
def synced(self) -> bool:
3740
"""
3841
Indicates whether the initial Client SS1 + Server SS2 handshake has been
3942
completed.
4043
"""
4144
return self._synced
4245

46+
4347
@synced.setter
44-
def synced(self, v: bool):
48+
def synced(self, v: bool) -> None:
4549
self._synced = v
4650
self.last_modified = datetime.now(timezone.utc)
4751

@@ -70,6 +74,11 @@ class YjsClientGroup:
7074
"""The poll time interval used while auto removing desynced clients"""
7175
desynced_timeout_seconds: int
7276
"""The max time period in seconds that a desynced client does not become synced before get auto removed from desynced dict"""
77+
_stopped: bool
78+
"""
79+
Whether the client group has been stopped. If `True`, this client group will
80+
ignore all future calls to `add()`.
81+
"""
7382

7483
def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop, poll_interval_seconds: int = 60, desynced_timeout_seconds: int = 120):
7584
self.room_id = room_id
@@ -80,9 +89,20 @@ def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop
8089
# self.loop.create_task(self._clean_desynced())
8190
self._poll_interval_seconds = poll_interval_seconds
8291
self.desynced_timeout_seconds = desynced_timeout_seconds
92+
self._stopped = False
8393

84-
def add(self, websocket: WebSocketHandler) -> str:
85-
"""Adds a pending client to the group. Returns a client ID."""
94+
def add(self, websocket: WebSocketHandler) -> str | None:
95+
"""
96+
Adds a new Websocket as a client to the group.
97+
98+
Returns a client ID if the client group is active.
99+
100+
Returns `None` if the client group is stopped, which is set after
101+
`stop()` is called.
102+
"""
103+
if self._stopped:
104+
return
105+
86106
client = YjsClient(websocket)
87107
self.desynced[client.id] = client
88108
return client.id
@@ -101,13 +121,16 @@ def mark_desynced(self, client_id: str) -> None:
101121

102122
def remove(self, client_id: str) -> None:
103123
"""Removes a client from the group."""
104-
if client := self.desynced.pop(client_id, None) is None:
124+
client = self.desynced.pop(client_id, None)
125+
if not client:
105126
client = self.synced.pop(client_id, None)
106-
if client and client.websocket and client.websocket.ws_connection:
107-
try:
108-
client.websocket.close()
109-
except Exception as e:
110-
self.log.exception(f"An exception occurred when remove client '{client_id}' for room '{self.room_id}': {e}")
127+
if not client:
128+
return
129+
130+
try:
131+
client.websocket.close()
132+
except Exception as e:
133+
self.log.exception(f"An exception occurred when remove client '{client_id}' for room '{self.room_id}': {e}")
111134

112135
def get(self, client_id: str) -> YjsClient:
113136
"""
@@ -128,10 +151,12 @@ def get_all(self, synced_only: bool = True) -> list[YjsClient]:
128151
Returns a list of all synced clients.
129152
Set synced_only=False to also get desynced clients.
130153
"""
131-
if synced_only:
132-
return list(client for client in self.synced.values() if client.websocket and client.websocket.ws_connection)
133-
return list(client for client in self.desynced.values() if client.websocket and client.websocket.ws_connection)
134-
154+
all_clients = [c for c in self.synced.values()]
155+
if not synced_only:
156+
all_clients += [c for c in self.desynced.values()]
157+
158+
return all_clients
159+
135160
def is_empty(self) -> bool:
136161
"""Returns whether the client group is empty."""
137162
return len(self.synced) == 0 and len(self.desynced) == 0
@@ -150,4 +175,16 @@ async def _clean_desynced(self) -> None:
150175
self.remove(client_id)
151176
except asyncio.CancelledError:
152177
break
153-
178+
179+
def stop(self):
180+
"""
181+
Closes all Websocket connections with close code 1001 (server
182+
shutting down) and ignores any future calls to `add()`.
183+
"""
184+
# Close all Websocket connections
185+
for client in self.get_all(synced_only=False):
186+
client.websocket.close(code=1001)
187+
188+
# Set `_stopped` to `True` to ignore future calls to `add()`
189+
self._stopped = True
190+

0 commit comments

Comments
 (0)