diff --git a/jupyter_rtc_core/rooms/yroom.py b/jupyter_rtc_core/rooms/yroom.py index 37d557c..34c40ee 100644 --- a/jupyter_rtc_core/rooms/yroom.py +++ b/jupyter_rtc_core/rooms/yroom.py @@ -15,6 +15,7 @@ from typing import Literal, Tuple, Any from jupyter_server_fileid.manager import BaseFileIdManager from jupyter_server.services.contents.manager import AsyncContentsManager, ContentsManager + from pycrdt import TransactionEvent class YRoom: """A Room to manage all client connection to one notebook file""" @@ -69,12 +70,12 @@ def __init__( type[YBaseDoc], jupyter_ydoc_classes.get(file_type, jupyter_ydoc_classes["file"]) ) - self.jupyter_ydoc = JupyterYDocClass(ydoc=self._ydoc, awareness=self._awareness) + self._jupyter_ydoc = JupyterYDocClass(ydoc=self._ydoc, awareness=self._awareness) # Initialize YRoomFileAPI and begin loading content self.file_api = YRoomFileAPI( room_id=self.room_id, - jupyter_ydoc=self.jupyter_ydoc, + jupyter_ydoc=self._jupyter_ydoc, log=self.log, loop=self._loop, fileid_manager=fileid_manager, @@ -88,8 +89,9 @@ def __init__( self._on_awareness_update ) self._ydoc_subscription = self._ydoc.observe( - lambda event: self._on_ydoc_update(event.update) + self._on_ydoc_update ) + self._jupyter_ydoc.observe(self._on_jupyter_ydoc_update) # Initialize message queue and start background task that routes new # messages in the message queue to the appropriate handler method. @@ -117,7 +119,7 @@ async def get_jupyter_ydoc(self): loaded from the ContentsManager. """ await self.file_api.ydoc_content_loaded - return self.jupyter_ydoc + return self._jupyter_ydoc async def get_ydoc(self): @@ -297,10 +299,11 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None: clients after this method is called via the `self._on_ydoc_update()` observer. """ - # Remove client and kill websocket if received SyncUpdate when client is desynced + # If client is desynced and sends a SyncUpdate, that results in a + # protocol error. Close the WebSocket and return early in this case. if self._should_ignore_update(client_id, "SyncUpdate"): - self.log.error(f"Should not receive SyncUpdate message when double handshake is not completed for client '{client_id}' and room '{self.room_id}'") - self._client_group.remove(client_id) + self.clients.remove(client_id) + return # Apply the SyncUpdate to the YDoc try: @@ -315,23 +318,33 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None: return - def _on_ydoc_update(self, message_payload: bytes, client_id: str | None = None) -> None: + def _on_ydoc_update(self, event: TransactionEvent) -> None: """ - This method is an observer on `self.ydoc` which: + This method is an observer on `self._ydoc` which broadcasts a + `SyncUpdate` message to all synced clients whenever the YDoc changes. - - Broadcasts a SyncUpdate message payload to all connected clients by - writing to their respective WebSockets, + The YDoc is saved in the `self._on_jupyter_ydoc_update()` observer. + """ + update: bytes = event.update + message = pycrdt.create_update_message(update) + self._broadcast_message(message, message_type="SyncUpdate") - - Persists the contents of the updated YDoc by writing to disk. - This method can also be called manually. + def _on_jupyter_ydoc_update(self, updated_key: str, *_) -> None: """ - # Broadcast the message - message = pycrdt.create_update_message(message_payload) - self._broadcast_message(message, message_type="SyncUpdate") + This method is an observer on `self._jupyter_ydoc` which saves the file + whenever the YDoc changes, unless `updated_key == "state"`. + + The `state` key is used by `jupyter_ydoc` to store temporary data like + whether a file is 'dirty' (has unsaved changes). This data is not + persisted, so changes to the `state` key should be ignored. Otherwise, + an infinite loop of saves will result, as saving sets `dirty = False`. - # Save the file to disk - self.file_api.schedule_save() + This observer is separate because `pycrdt.Doc.observe()` does not pass + `updated_key` to `self._on_ydoc_update()`. + """ + if updated_key != "state": + self.file_api.schedule_save() def handle_awareness_update(self, client_id: str, message: bytes) -> None: @@ -377,6 +390,15 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd fails. `message_type` is used to produce more readable warnings. """ clients = self.clients.get_all() + client_count = len(clients) + if not client_count: + return + + if message_type == "SyncUpdate": + self.log.info( + f"Broadcasting SyncUpdate to all {client_count} synced clients." + ) + for client in clients: try: # TODO: remove this assertion once websocket is made required @@ -389,6 +411,12 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd f"to client '{client.id}:'" ) self.log.exception(e) + continue + + if message_type == "SyncUpdate": + self.log.info( + f"Broadcast of SyncUpdate complete." + ) def _on_awareness_update(self, type: str, changes: tuple[dict[str, Any], Any]) -> None: """ @@ -417,6 +445,7 @@ async def stop(self) -> None: # Remove all observers, as updates no longer need to be broadcast self._ydoc.unobserve(self._ydoc_subscription) self._awareness.unobserve(self._awareness_subscription) + self._jupyter_ydoc.unobserve() # Finish processing all messages, then stop the queue to end the # `_process_message_queue()` background task.