Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions jupyter_rtc_core/rooms/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Literal, Tuple, Any
from jupyter_server_fileid.manager import BaseFileIdManager
from jupyter_server.services.contents.manager import AsyncContentsManager, ContentsManager
from pycrdt import TransactionEvent

class YRoom:
"""A Room to manage all client connection to one notebook file"""
Expand Down Expand Up @@ -69,12 +70,12 @@ def __init__(
type[YBaseDoc],
jupyter_ydoc_classes.get(file_type, jupyter_ydoc_classes["file"])
)
self.jupyter_ydoc = JupyterYDocClass(ydoc=self._ydoc, awareness=self._awareness)
self._jupyter_ydoc = JupyterYDocClass(ydoc=self._ydoc, awareness=self._awareness)

# Initialize YRoomFileAPI and begin loading content
self.file_api = YRoomFileAPI(
room_id=self.room_id,
jupyter_ydoc=self.jupyter_ydoc,
jupyter_ydoc=self._jupyter_ydoc,
log=self.log,
loop=self._loop,
fileid_manager=fileid_manager,
Expand All @@ -88,8 +89,9 @@ def __init__(
self._on_awareness_update
)
self._ydoc_subscription = self._ydoc.observe(
lambda event: self._on_ydoc_update(event.update)
self._on_ydoc_update
)
self._jupyter_ydoc.observe(self._on_jupyter_ydoc_update)

# Initialize message queue and start background task that routes new
# messages in the message queue to the appropriate handler method.
Expand Down Expand Up @@ -117,7 +119,7 @@ async def get_jupyter_ydoc(self):
loaded from the ContentsManager.
"""
await self.file_api.ydoc_content_loaded
return self.jupyter_ydoc
return self._jupyter_ydoc


async def get_ydoc(self):
Expand Down Expand Up @@ -297,10 +299,11 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None:
clients after this method is called via the `self._on_ydoc_update()`
observer.
"""
# Remove client and kill websocket if received SyncUpdate when client is desynced
# If client is desynced and sends a SyncUpdate, that results in a
# protocol error. Close the WebSocket and return early in this case.
if self._should_ignore_update(client_id, "SyncUpdate"):
self.log.error(f"Should not receive SyncUpdate message when double handshake is not completed for client '{client_id}' and room '{self.room_id}'")
self._client_group.remove(client_id)
self.clients.remove(client_id)
return

# Apply the SyncUpdate to the YDoc
try:
Expand All @@ -315,23 +318,33 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None:
return


def _on_ydoc_update(self, message_payload: bytes, client_id: str | None = None) -> None:
def _on_ydoc_update(self, event: TransactionEvent) -> None:
"""
This method is an observer on `self.ydoc` which:
This method is an observer on `self._ydoc` which broadcasts a
`SyncUpdate` message to all synced clients whenever the YDoc changes.

- Broadcasts a SyncUpdate message payload to all connected clients by
writing to their respective WebSockets,
The YDoc is saved in the `self._on_jupyter_ydoc_update()` observer.
"""
update: bytes = event.update
message = pycrdt.create_update_message(update)
self._broadcast_message(message, message_type="SyncUpdate")

- Persists the contents of the updated YDoc by writing to disk.

This method can also be called manually.
def _on_jupyter_ydoc_update(self, updated_key: str, *_) -> None:
"""
# Broadcast the message
message = pycrdt.create_update_message(message_payload)
self._broadcast_message(message, message_type="SyncUpdate")
This method is an observer on `self._jupyter_ydoc` which saves the file
whenever the YDoc changes, unless `updated_key == "state"`.

The `state` key is used by `jupyter_ydoc` to store temporary data like
whether a file is 'dirty' (has unsaved changes). This data is not
persisted, so changes to the `state` key should be ignored. Otherwise,
an infinite loop of saves will result, as saving sets `dirty = False`.

# Save the file to disk
self.file_api.schedule_save()
This observer is separate because `pycrdt.Doc.observe()` does not pass
`updated_key` to `self._on_ydoc_update()`.
"""
if updated_key != "state":
self.file_api.schedule_save()


def handle_awareness_update(self, client_id: str, message: bytes) -> None:
Expand Down Expand Up @@ -377,6 +390,15 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd
fails. `message_type` is used to produce more readable warnings.
"""
clients = self.clients.get_all()
client_count = len(clients)
if not client_count:
return

if message_type == "SyncUpdate":
self.log.info(
f"Broadcasting SyncUpdate to all {client_count} synced clients."
)

for client in clients:
try:
# TODO: remove this assertion once websocket is made required
Expand All @@ -389,6 +411,12 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd
f"to client '{client.id}:'"
)
self.log.exception(e)
continue

if message_type == "SyncUpdate":
self.log.info(
f"Broadcast of SyncUpdate complete."
)

def _on_awareness_update(self, type: str, changes: tuple[dict[str, Any], Any]) -> None:
"""
Expand Down Expand Up @@ -417,6 +445,7 @@ async def stop(self) -> None:
# Remove all observers, as updates no longer need to be broadcast
self._ydoc.unobserve(self._ydoc_subscription)
self._awareness.unobserve(self._awareness_subscription)
self._jupyter_ydoc.unobserve()

# Finish processing all messages, then stop the queue to end the
# `_process_message_queue()` background task.
Expand Down
Loading