diff --git a/jupyter_rtc_core/app.py b/jupyter_rtc_core/app.py index 8a4f0c4..4943055 100644 --- a/jupyter_rtc_core/app.py +++ b/jupyter_rtc_core/app.py @@ -31,12 +31,9 @@ class RtcExtensionApp(ExtensionApp): default_value=YRoomManager, ).tag(config=True) - yroom_manager = Instance( - klass=YRoomManager, - help="An instance of the YRoom Manager.", - allow_none=True - ).tag(config=True) - + @property + def yroom_manager(self) -> YRoomManager | None: + return self.settings.get("yroom_manager", None) def initialize(self): super().initialize() @@ -59,7 +56,6 @@ def get_fileid_manager(): loop=loop, log=log ) - pass def _link_jupyter_server_extension(self, server_app): @@ -71,3 +67,9 @@ def _link_jupyter_server_extension(self, server_app): c.ServerApp.session_manager_class = "jupyter_rtc_core.session_manager.YDocSessionManager" server_app.update_config(c) super()._link_jupyter_server_extension(server_app) + + async def stop_extension(self): + self.log.info("Stopping `jupyter_rtc_core` server extension.") + if self.yroom_manager: + await self.yroom_manager.stop() + self.log.info("`jupyter_rtc_core` server extension is shut down. Goodbye!") diff --git a/jupyter_rtc_core/rooms/yroom.py b/jupyter_rtc_core/rooms/yroom.py index 90b260a..37d557c 100644 --- a/jupyter_rtc_core/rooms/yroom.py +++ b/jupyter_rtc_core/rooms/yroom.py @@ -85,16 +85,16 @@ def __init__( # Start observers on `self.ydoc` and `self.awareness` to ensure new # updates are broadcast to all clients and saved to disk. self._awareness_subscription = self._awareness.observe( - self.send_server_awareness + self._on_awareness_update ) self._ydoc_subscription = self._ydoc.observe( - lambda event: self.write_sync_update(event.update) + lambda event: self._on_ydoc_update(event.update) ) # Initialize message queue and start background task that routes new # messages in the message queue to the appropriate handler method. self._message_queue = asyncio.Queue() - self._loop.create_task(self._on_new_message()) + self._loop.create_task(self._process_message_queue()) # Log notification that room is ready self.log.info(f"Room '{self.room_id}' initialized.") @@ -144,7 +144,7 @@ def add_message(self, client_id: str, message: bytes) -> None: self._message_queue.put_nowait((client_id, message)) - async def _on_new_message(self) -> None: + async def _process_message_queue(self) -> None: """ Async method that only runs when a new message arrives in the message queue. This method routes the message to a handler method based on the @@ -162,41 +162,54 @@ async def _on_new_message(self) -> None: except asyncio.QueueShutDown: break - # Handle Awareness messages + # Determine message type & subtype from header message_type = message[0] - if message_type == YMessageType.AWARENESS: + sync_message_subtype = "*" + if message_type == YMessageType.SYNC and len(message) >= 2: + sync_message_subtype = message[1] + + # Determine if message is invalid + invalid_message_type = message_type not in YMessageType + invalid_sync_message_type = message_type == YMessageType.SYNC and sync_message_subtype not in YSyncMessageSubtype + invalid_message = invalid_message_type or invalid_sync_message_type + + # Handle invalid messages by logging a warning and ignoring + if invalid_message: + self.log.warning( + "Ignoring an unrecognized message with header " + f"'{message_type},{sync_message_subtype}' from client " + "'{client_id}'. Messages must have one of the following " + "headers: '0,0' (SyncStep1), '0,1' (SyncStep2), " + "'0,2' (SyncUpdate), or '1,*' (AwarenessUpdate)." + ) + # Handle Awareness messages + elif message_type == YMessageType.AWARENESS: self.log.debug(f"Received AwarenessUpdate from '{client_id}'.") self.handle_awareness_update(client_id, message) self.log.debug(f"Handled AwarenessUpdate from '{client_id}'.") - continue - # Handle Sync messages - assert message_type == YMessageType.SYNC - message_subtype = message[1] if len(message) >= 2 else None - if message_subtype == YSyncMessageSubtype.SYNC_STEP1: + elif sync_message_subtype == YSyncMessageSubtype.SYNC_STEP1: self.log.info(f"Received SS1 from '{client_id}'.") self.handle_sync_step1(client_id, message) self.log.info(f"Handled SS1 from '{client_id}'.") - continue - elif message_subtype == YSyncMessageSubtype.SYNC_STEP2: + elif sync_message_subtype == YSyncMessageSubtype.SYNC_STEP2: self.log.info(f"Received SS2 from '{client_id}'.") self.handle_sync_step2(client_id, message) self.log.info(f"Handled SS2 from '{client_id}'.") - continue - elif message_subtype == YSyncMessageSubtype.SYNC_UPDATE: + elif sync_message_subtype == YSyncMessageSubtype.SYNC_UPDATE: self.log.info(f"Received SyncUpdate from '{client_id}'.") self.handle_sync_update(client_id, message) self.log.info(f"Handled SyncUpdate from '{client_id}'.") - continue - else: - self.log.warning( - "Ignoring an unrecognized message with header " - f"'{message_type},{message_subtype}' from client " - "'{client_id}'. Messages must have one of the following " - "headers: '0,0' (SyncStep1), '0,1' (SyncStep2), " - "'0,2' (SyncUpdate), or '1,*' (AwarenessUpdate)." - ) - continue + + # Finally, inform the asyncio Queue that the task was complete + # This is required for `self._message_queue.join()` to unblock once + # queue is empty in `self.stop()`. + self._message_queue.task_done() + + self.log.info( + "Stopped `self._process_message_queue()` background task " + f"for YRoom '{self.room_id}'." + ) def handle_sync_step1(self, client_id: str, message: bytes) -> None: @@ -281,7 +294,7 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None: Handles incoming SyncUpdate messages by applying the update to the YDoc. A SyncUpdate message will automatically be broadcast to all synced - clients after this method is called via the `self.write_sync_update()` + clients after this method is called via the `self._on_ydoc_update()` observer. """ # Remove client and kill websocket if received SyncUpdate when client is desynced @@ -302,7 +315,7 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None: return - def write_sync_update(self, message_payload: bytes, client_id: str | None = None) -> None: + def _on_ydoc_update(self, message_payload: bytes, client_id: str | None = None) -> None: """ This method is an observer on `self.ydoc` which: @@ -377,9 +390,10 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd ) self.log.exception(e) - def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any]) -> None: + def _on_awareness_update(self, type: str, changes: tuple[dict[str, Any], Any]) -> None: """ - Callback to broadcast the server awareness to clients. + Observer on `self.awareness` that broadcasts AwarenessUpdate messages to + all clients on update. Arguments: type: The change type. @@ -393,14 +407,22 @@ def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any]) message = pycrdt.create_awareness_message(state) self._broadcast_message(message, "AwarenessUpdate") - def stop(self) -> None: + async def stop(self) -> None: """ - Stop the YRoom. - - TODO: stop file API & stop the message processing loop + Stops the YRoom gracefully. """ + # First, disconnect all clients by stopping the client group. + self.clients.stop() + + # Remove all observers, as updates no longer need to be broadcast self._ydoc.unobserve(self._ydoc_subscription) self._awareness.unobserve(self._awareness_subscription) - return + # Finish processing all messages, then stop the queue to end the + # `_process_message_queue()` background task. + await self._message_queue.join() + self._message_queue.shutdown(immediate=True) + # Finally, stop FileAPI and return. This saves the final content of the + # JupyterYDoc in the process. + await self.file_api.stop() diff --git a/jupyter_rtc_core/rooms/yroom_file_api.py b/jupyter_rtc_core/rooms/yroom_file_api.py index c05e959..27b4882 100644 --- a/jupyter_rtc_core/rooms/yroom_file_api.py +++ b/jupyter_rtc_core/rooms/yroom_file_api.py @@ -161,6 +161,11 @@ def schedule_save(self) -> None: async def _process_scheduled_saves(self) -> None: + """ + Defines a background task that processes scheduled saves, after waiting + for the JupyterYDoc content to be loaded. + """ + # Wait for content to be loaded before processing scheduled saves await self._ydoc_content_loaded.wait() @@ -168,31 +173,64 @@ async def _process_scheduled_saves(self) -> None: try: await self._scheduled_saves.get() except asyncio.QueueShutDown: - return + break + + await self._save_jupyter_ydoc() - try: - assert self.jupyter_ydoc - path = self.get_path() - content = self.jupyter_ydoc.source - file_format = self.file_format - file_type = self.file_type if self.file_type in SAVEABLE_FILE_TYPES else "file" - - # Save the YDoc via the ContentsManager - await ensure_async(self._contents_manager.save( - { - "format": file_format, - "type": file_type, - "content": content, - }, - path - )) - - # Mark 'dirty' as `False`. This hides the "unsaved changes" icon - # in the JupyterLab tab rendering this YDoc in the frontend. - self.jupyter_ydoc.dirty = False - except Exception as e: - self.log.error("An exception occurred when saving JupyterYDoc.") - self.log.exception(e) + self.log.info( + "Stopped `self._process_scheduled_save()` background task " + f"for YRoom '{self.room_id}'." + ) + + + async def _save_jupyter_ydoc(self): + """ + Saves the JupyterYDoc to disk immediately. + + This is a private method, and should only be called through the + `_process_scheduled_saves()` task and the `stop()` method. Consumers + should instead call `schedule_save()` to save the document. + """ + try: + assert self.jupyter_ydoc + path = self.get_path() + content = self.jupyter_ydoc.source + file_format = self.file_format + file_type = self.file_type if self.file_type in SAVEABLE_FILE_TYPES else "file" + + # Save the YDoc via the ContentsManager + await ensure_async(self._contents_manager.save( + { + "format": file_format, + "type": file_type, + "content": content, + }, + path + )) + + # Mark 'dirty' as `False`. This hides the "unsaved changes" icon + # in the JupyterLab tab rendering this YDoc in the frontend. + self.jupyter_ydoc.dirty = False + except Exception as e: + self.log.error("An exception occurred when saving JupyterYDoc.") + self.log.exception(e) + + async def stop(self) -> None: + """ + Gracefully stops the YRoomFileAPI, saving the content of + `self.jupyter_ydoc` before exiting. + """ + # Stop the `self._process_scheduled_saves()` background task + # immediately. This is safe since we save before stopping anyways. + self._scheduled_saves.shutdown(immediate=True) + + # Do nothing if content was not loaded first. This prevents overwriting + # the existing file with an empty JupyterYDoc. + if not (self._ydoc_content_loaded.is_set()): + return + + # Save the file and return. + await self._save_jupyter_ydoc() # see https://github.com/jupyterlab/jupyter-collaboration/blob/main/projects/jupyter-server-ydoc/jupyter_server_ydoc/loaders.py#L146-L149 diff --git a/jupyter_rtc_core/rooms/yroom_manager.py b/jupyter_rtc_core/rooms/yroom_manager.py index 91dbba4..3005f7b 100644 --- a/jupyter_rtc_core/rooms/yroom_manager.py +++ b/jupyter_rtc_core/rooms/yroom_manager.py @@ -1,6 +1,4 @@ from __future__ import annotations -from typing import Any, Dict, Optional -from traitlets import HasTraits, Instance, default from .yroom import YRoom from typing import TYPE_CHECKING @@ -27,8 +25,9 @@ def __init__( self.contents_manager = contents_manager self.loop = loop self.log = log - self._rooms_by_id = {} + # Initialize dictionary of YRooms, keyed by room ID + self._rooms_by_id = {} @property @@ -66,16 +65,26 @@ def get_room(self, room_id: str) -> YRoom | None: return None - def delete_room(self, room_id: str) -> None: + async def delete_room(self, room_id: str) -> None: """ - Deletes a YRoom given a room ID. - - TODO: finish implementing YRoom.stop(), and delete empty rooms w/ no - live kernels automatically in a background task. + Gracefully deletes a YRoom given a room ID. This stops the YRoom first, + which finishes applying all updates & saves the content automatically. """ - yroom = self._rooms_by_id.get(room_id, None) + yroom = self._rooms_by_id.pop(room_id, None) if not yroom: return - yroom.stop() - del self._rooms_by_id[room_id] + self.log.info(f"Stopping YRoom '{room_id}'.") + await yroom.stop() + self.log.info(f"Stopped YRoom '{room_id}'.") + + + async def stop(self) -> None: + """ + Gracefully deletes each `YRoom`. See `delete_room()` for more info. + """ + self.log.info(f"Stopping `YRoomManager` and deleting all YRooms.") + room_ids = list(self._rooms_by_id.keys()) + for room_id in room_ids: + await self.delete_room(room_id) + self.log.info(f"Stopped `YRoomManager` and deleted all YRooms.") diff --git a/jupyter_rtc_core/websockets/clients.py b/jupyter_rtc_core/websockets/clients.py index a4f2735..55f85da 100644 --- a/jupyter_rtc_core/websockets/clients.py +++ b/jupyter_rtc_core/websockets/clients.py @@ -13,11 +13,13 @@ if TYPE_CHECKING: from tornado.websocket import WebSocketHandler + + class YjsClient: """Data model that represents all data associated with a user connecting to a YDoc through JupyterLab.""" - websocket: WebSocketHandler | None + websocket: WebSocketHandler """The Tornado WebSocketHandler handling the WS connection to this client.""" id: str """UUIDv4 string that uniquely identifies this client.""" @@ -27,21 +29,23 @@ class YjsClient: _synced: bool def __init__(self, websocket): - self.websocket: WebSocketHandler | None = websocket - self.id: str = str(uuid.uuid4()) - self._synced: bool = False + self.websocket = websocket + self.id = str(uuid.uuid4()) + self._synced = False self.last_modified = datetime.now(timezone.utc) + @property - def synced(self): + def synced(self) -> bool: """ Indicates whether the initial Client SS1 + Server SS2 handshake has been completed. """ return self._synced + @synced.setter - def synced(self, v: bool): + def synced(self, v: bool) -> None: self._synced = v self.last_modified = datetime.now(timezone.utc) @@ -70,6 +74,11 @@ class YjsClientGroup: """The poll time interval used while auto removing desynced clients""" desynced_timeout_seconds: int """The max time period in seconds that a desynced client does not become synced before get auto removed from desynced dict""" + _stopped: bool + """ + Whether the client group has been stopped. If `True`, this client group will + ignore all future calls to `add()`. + """ def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop, poll_interval_seconds: int = 60, desynced_timeout_seconds: int = 120): self.room_id = room_id @@ -80,9 +89,20 @@ def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop # self.loop.create_task(self._clean_desynced()) self._poll_interval_seconds = poll_interval_seconds self.desynced_timeout_seconds = desynced_timeout_seconds + self._stopped = False - def add(self, websocket: WebSocketHandler) -> str: - """Adds a pending client to the group. Returns a client ID.""" + def add(self, websocket: WebSocketHandler) -> str | None: + """ + Adds a new Websocket as a client to the group. + + Returns a client ID if the client group is active. + + Returns `None` if the client group is stopped, which is set after + `stop()` is called. + """ + if self._stopped: + return + client = YjsClient(websocket) self.desynced[client.id] = client return client.id @@ -101,13 +121,16 @@ def mark_desynced(self, client_id: str) -> None: def remove(self, client_id: str) -> None: """Removes a client from the group.""" - if client := self.desynced.pop(client_id, None) is None: + client = self.desynced.pop(client_id, None) + if not client: client = self.synced.pop(client_id, None) - if client and client.websocket and client.websocket.ws_connection: - try: - client.websocket.close() - except Exception as e: - self.log.exception(f"An exception occurred when remove client '{client_id}' for room '{self.room_id}': {e}") + if not client: + return + + try: + client.websocket.close() + except Exception as e: + self.log.exception(f"An exception occurred when remove client '{client_id}' for room '{self.room_id}': {e}") def get(self, client_id: str) -> YjsClient: """ @@ -128,10 +151,12 @@ def get_all(self, synced_only: bool = True) -> list[YjsClient]: Returns a list of all synced clients. Set synced_only=False to also get desynced clients. """ - if synced_only: - return list(client for client in self.synced.values() if client.websocket and client.websocket.ws_connection) - return list(client for client in self.desynced.values() if client.websocket and client.websocket.ws_connection) - + all_clients = [c for c in self.synced.values()] + if not synced_only: + all_clients += [c for c in self.desynced.values()] + + return all_clients + def is_empty(self) -> bool: """Returns whether the client group is empty.""" return len(self.synced) == 0 and len(self.desynced) == 0 @@ -150,4 +175,29 @@ async def _clean_desynced(self) -> None: self.remove(client_id) except asyncio.CancelledError: break - + + def stop(self): + """ + Closes all Websocket connections with close code 1001 (server + shutting down) and ignores any future calls to `add()`. + """ + # Remove all clients from both dictionaries + client_ids = set(self.desynced.keys()) | set(self.synced.keys()) + clients: list[YjsClient] = [] + for client_id in client_ids: + client = self.desynced.pop(client_id, None) + if not client: + client = self.synced.pop(client_id, None) + if client: + clients.append(client) + + assert len(self.desynced) == 0 + assert len(self.synced) == 0 + + # Close all Websocket connections + for client in clients: + client.websocket.close(code=1001) + + # Set `_stopped` to `True` to ignore future calls to `add()` + self._stopped = True + \ No newline at end of file diff --git a/jupyter_rtc_core/websockets/yroom_ws.py b/jupyter_rtc_core/websockets/yroom_ws.py index c91ba19..6cebc0a 100644 --- a/jupyter_rtc_core/websockets/yroom_ws.py +++ b/jupyter_rtc_core/websockets/yroom_ws.py @@ -14,7 +14,7 @@ class YRoomWebsocket(WebSocketHandler): yroom: YRoom room_id: str - client_id: str + client_id: str | None # TODO: change this. we should pass `self.log` from our # `ExtensionApp` to log messages w/ "RtcCoreExtension" prefix log = logging.Logger("TEMP") @@ -64,8 +64,13 @@ def open(self, *_, **__): raise HTTPError(500, f"Unable to initialize YRoom '{self.room_id}'.") self.yroom = yroom - # Add self as a client to the YRoom + # Add self as a client to the YRoom. + # If `client_id is None`, then the YRoom is being stopped, and this WS + # should be closed immediately w/ close code 1001. self.client_id = self.yroom.clients.add(self) + if not self.client_id: + self.close(code=1001) + return def on_message(self, message: bytes): @@ -73,6 +78,10 @@ def on_message(self, message: bytes): if self.room_id == "JupyterLab:globalAwareness": return + if not self.client_id: + self.close(code=1001) + return + # Route all messages to the YRoom for processing self.yroom.add_message(self.client_id, message) @@ -82,5 +91,6 @@ def on_close(self): if self.room_id == "JupyterLab:globalAwareness": return - self.log.info(f"Closed Websocket to client '{self.client_id}'.") - self.yroom.clients.remove(self.client_id) + if self.client_id: + self.log.info(f"Closed Websocket to client '{self.client_id}'.") + self.yroom.clients.remove(self.client_id)