Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,7 @@ dmypy.json

# Yarn cache
.yarn/

# Ignore files used to test in real-time
untitled*
Untitled*
11 changes: 11 additions & 0 deletions devenv-jcollab-frontend.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: rtccore-jcollab-frontend
channels:
- conda-forge
dependencies:
- python
- nodejs=22
- uv
- jupyterlab
- pip:
- jupyter_docprovider>=2.0.2,<3
- jupyter_collaboration_ui>=2.0.2,<3
13 changes: 8 additions & 5 deletions jupyter_rtc_core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from traitlets.config import Config
import asyncio

from traitlets import Instance
from traitlets import Type
from .handlers import RouteHandler
from traitlets import Instance, Type
from .handlers import RouteHandler, YRoomSessionHandler
from .websockets import GlobalAwarenessWebsocket, YRoomWebsocket
from .rooms.yroom_manager import YRoomManager

Expand All @@ -20,7 +19,9 @@ class RtcExtensionApp(ExtensionApp):
# global awareness websocket
# (r"api/collaboration/room/JupyterLab:globalAwareness/?", GlobalAwarenessWebsocket),
# # ydoc websocket
# (r"api/collaboration/room/(.*)", YRoomWebsocket)
(r"api/collaboration/room/(.*)", YRoomWebsocket),
# handler that just adds compatibility with Jupyter Collaboration's frontend
(r"api/collaboration/session/(.*)", YRoomSessionHandler)
]

yroom_manager_class = Type(
Expand All @@ -44,7 +45,8 @@ def initialize_settings(self):
# We cannot access the 'file_id_manager' key immediately because server
# extensions initialize in alphabetical order. 'jupyter_rtc_core' <
# 'jupyter_server_fileid'.
get_fileid_manager = lambda: self.settings["file_id_manager"]
def get_fileid_manager():
return self.serverapp.web_app.settings["file_id_manager"]
contents_manager = self.serverapp.contents_manager
loop = asyncio.get_event_loop_policy().get_event_loop()
log = self.log
Expand All @@ -56,6 +58,7 @@ def initialize_settings(self):
loop=loop,
log=log
)
pass


def _link_jupyter_server_extension(self, server_app):
Expand Down
30 changes: 30 additions & 0 deletions jupyter_rtc_core/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import uuid

from jupyter_server.base.handlers import APIHandler
import tornado
Expand All @@ -14,4 +15,33 @@ def get(self):
}))


# TODO: remove this by v1.0.0 if deemed unnecessary. Just adding this for
# compatibility with the `jupyter_collaboration` frontend.
class YRoomSessionHandler(APIHandler):
SESSION_ID = str(uuid.uuid4())

@tornado.web.authenticated
def put(self, path):
body = json.loads(self.request.body)
format = body["format"]
content_type = body["type"]
# self.log.info("IN HANDLER")
# for k, v in self.settings.items():
# print(f"{k}: {v}")
# print(len(self.settings.items()))
# print(id(self.settings))

file_id_manager = self.settings["file_id_manager"]
file_id = file_id_manager.index(path)

data = json.dumps(
{
"format": format,
"type": content_type,
"fileId": file_id,
"sessionId": self.SESSION_ID,
}
)
self.set_status(200)
self.finish(data)

43 changes: 30 additions & 13 deletions jupyter_rtc_core/rooms/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .yroom_file_api import YRoomFileAPI

if TYPE_CHECKING:
from typing import Literal, Tuple
from typing import Literal, Tuple, Any
from jupyter_server_fileid.manager import BaseFileIdManager
from jupyter_server.services.contents.manager import AsyncContentsManager, ContentsManager

Expand All @@ -22,7 +22,11 @@ class YRoom:
log: Logger
"""Log object"""
room_id: str
"""Room Id"""
"""
The ID of the room. This is a composite ID following the format:

room_id := "{file_type}:{file_format}:{file_id}"
"""

_jupyter_ydoc: YBaseDoc
"""JupyterYDoc"""
Expand Down Expand Up @@ -60,9 +64,10 @@ def __init__(
self._client_group = YjsClientGroup(room_id=room_id, log=self.log, loop=self._loop)
self._ydoc = pycrdt.Doc()
self._awareness = pycrdt.Awareness(ydoc=self._ydoc)
_, file_type, _ = self.room_id.split(":")
JupyterYDocClass = cast(
type[YBaseDoc],
jupyter_ydoc_classes.get(self.file_type, jupyter_ydoc_classes["file"])
jupyter_ydoc_classes.get(file_type, jupyter_ydoc_classes["file"])
)
self.jupyter_ydoc = JupyterYDocClass(ydoc=self._ydoc, awareness=self._awareness)

Expand Down Expand Up @@ -90,6 +95,9 @@ def __init__(
# messages in the message queue to the appropriate handler method.
self._message_queue = asyncio.Queue()
self._loop.create_task(self._on_new_message())

# Log notification that room is ready
self.log.info(f"Room '{self.room_id}' initialized.")


@property
Expand Down Expand Up @@ -157,20 +165,28 @@ async def _on_new_message(self) -> None:
# Handle Awareness messages
message_type = message[0]
if message_type == YMessageType.AWARENESS:
self.handle_awareness_update(client_id, message[1:])
self.log.debug(f"Received AwarenessUpdate from '{client_id}'.")
self.handle_awareness_update(client_id, message)
self.log.debug(f"Handled AwarenessUpdate from '{client_id}'.")
continue

# Handle Sync messages
assert message_type == YMessageType.SYNC
message_subtype = message[1] if len(message) >= 2 else None
if message_subtype == YSyncMessageSubtype.SYNC_STEP1:
self.log.info(f"Received SS1 from '{client_id}'.")
self.handle_sync_step1(client_id, message)
self.log.info(f"Handled SS1 from '{client_id}'.")
continue
elif message_subtype == YSyncMessageSubtype.SYNC_STEP2:
self.log.info(f"Received SS2 from '{client_id}'.")
self.handle_sync_step2(client_id, message)
self.log.info(f"Handled SS2 from '{client_id}'.")
continue
elif message_subtype == YSyncMessageSubtype.SYNC_UPDATE:
self.log.info(f"Received SyncUpdate from '{client_id}'.")
self.handle_sync_update(client_id, message)
self.log.info(f"Handled SyncUpdate from '{client_id}'.")
continue
else:
self.log.warning(
Expand Down Expand Up @@ -213,7 +229,8 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
try:
# TODO: remove the assert once websocket is made required
assert isinstance(new_client.websocket, WebSocketHandler)
new_client.websocket.write_message(sync_step2_message)
new_client.websocket.write_message(sync_step2_message, binary=True)
self.log.info(f"Sent SS2 reply to client '{client_id}'.")
except Exception as e:
self.log.error(
"An exception occurred when writing the SyncStep2 reply "
Expand All @@ -228,7 +245,8 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
try:
assert isinstance(new_client.websocket, WebSocketHandler)
sync_step1_message = pycrdt.create_sync_message(self._ydoc)
new_client.websocket.write_message(sync_step1_message)
new_client.websocket.write_message(sync_step1_message, binary=True)
self.log.info(f"Sent SS1 message to client '{client_id}'.")
except Exception as e:
self.log.error(
"An exception occurred when writing a SyncStep1 message "
Expand Down Expand Up @@ -295,23 +313,22 @@ def write_sync_update(self, message_payload: bytes, client_id: str | None = None

This method can also be called manually.
"""
# Broadcast the message:
# Broadcast the message
message = pycrdt.create_update_message(message_payload)
self._broadcast_message(message, message_type="SyncUpdate")

# Save the file to disk.
# TODO: requires YRoomLoader implementation
return
# Save the file to disk
self.file_api.schedule_save()


def handle_awareness_update(self, client_id: str, message: bytes) -> None:
# Apply the AwarenessUpdate message
try:
message_payload = message[1:]
message_payload = pycrdt.read_message(message[1:])
self._awareness.apply_awareness_update(message_payload, origin=self)
except Exception as e:
self.log.error(
"An exception occurred when applying an AwarenessUpdate"
"An exception occurred when applying an AwarenessUpdate "
f"message from client '{client_id}':"
)
self.log.exception(e)
Expand Down Expand Up @@ -351,7 +368,7 @@ def _broadcast_message(self, message: bytes, message_type: Literal['AwarenessUpd
try:
# TODO: remove this assertion once websocket is made required
assert isinstance(client.websocket, WebSocketHandler)
client.websocket.write_message(message)
client.websocket.write_message(message, binary=True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent an hour trying to figure out why the client was auto-closing every WS right after we write a message to it. Turns out, you need to pass binary=True to write byte streams. Otherwise, Tornado will automatically encode byte streams as UTF-8, which the client will reject with an opaque error log. 🤦

From the documentation:

def write_message(
self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False
) -> "Future[None]":
"""Sends the given message to the client of this Web Socket.

The message may be either a string or a dict (which will be
encoded as json). If the binary argument is false, the
message will be sent as utf8; in binary mode any byte string
is allowed.

except Exception as e:
self.log.warning(
f"An exception occurred when broadcasting a "
Expand Down
10 changes: 10 additions & 0 deletions jupyter_rtc_core/rooms/yroom_file_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class YRoomFileAPI:

# See `filemanager.py` in `jupyter_server` for references on supported file
# formats & file types.
room_id: str
file_format: Literal["text", "base64"]
file_type: Literal["file", "notebook"]
file_id: str
Expand All @@ -57,6 +58,7 @@ def __init__(
loop: asyncio.AbstractEventLoop
):
# Bind instance attributes
self.room_id = room_id
self.file_format, self.file_type, self.file_id = room_id.split(":")
self.jupyter_ydoc = jupyter_ydoc
self.log = log
Expand Down Expand Up @@ -113,6 +115,8 @@ def load_ydoc_content(self) -> None:
# Otherwise, set loading to `True` and start the loading task.
if self._ydoc_content_loaded.is_set() or self._ydoc_content_loading:
return

self.log.info(f"Loading content for room ID '{self.room_id}'.")
self._ydoc_content_loading = True
self._loop.create_task(self._load_ydoc_content())

Expand All @@ -134,6 +138,7 @@ async def _load_ydoc_content(self) -> None:
# Also set loading to `False` for consistency
self._ydoc_content_loaded.set()
self._ydoc_content_loading = False
self.log.info(f"Loaded content for room ID '{self.room_id}'.")


def schedule_save(self) -> None:
Expand Down Expand Up @@ -172,6 +177,7 @@ async def _process_scheduled_saves(self) -> None:
file_format = self.file_format
file_type = self.file_type if self.file_type in SAVEABLE_FILE_TYPES else "file"

# Save the YDoc via the ContentsManager
await ensure_async(self._contents_manager.save(
{
"format": file_format,
Expand All @@ -180,6 +186,10 @@ async def _process_scheduled_saves(self) -> None:
},
path
))

# Mark 'dirty' as `False`. This hides the "unsaved changes" icon
# in the JupyterLab tab rendering this YDoc in the frontend.
self.jupyter_ydoc.dirty = False
except Exception as e:
self.log.error("An exception occurred when saving JupyterYDoc.")
self.log.exception(e)
Expand Down
8 changes: 4 additions & 4 deletions jupyter_rtc_core/rooms/yroom_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from traitlets import HasTraits, Instance, default
from __future__ import annotations

from .yroom import YRoom
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import asyncio
import logging
from typing import Callable
from jupyter_server_fileid.manager import BaseFileIdManager
from jupyter_server.services.contents.manager import AsyncContentsManager, ContentsManager

Expand All @@ -18,7 +17,7 @@ class YRoomManager():
def __init__(
self,
*,
get_fileid_manager: Callable[[], BaseFileIdManager],
get_fileid_manager: callable[[], BaseFileIdManager],
contents_manager: AsyncContentsManager | ContentsManager,
loop: asyncio.AbstractEventLoop,
log: logging.Logger,
Expand All @@ -35,7 +34,7 @@ def __init__(
@property
def fileid_manager(self) -> BaseFileIdManager:
return self._get_fileid_manager()


def get_room(self, room_id: str) -> YRoom | None:
"""
Expand All @@ -49,6 +48,7 @@ def get_room(self, room_id: str) -> YRoom | None:

# Otherwise, create a new room
try:
self.log.info(f"Initializing room '{room_id}'.")
yroom = YRoom(
room_id=room_id,
log=self.log,
Expand Down
2 changes: 1 addition & 1 deletion jupyter_rtc_core/websockets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .global_awareness_ws import GlobalAwarenessWebsocket
from .yroom_ws import YRoomWebsocket
from .clients import YjsClient, YjsClientGroup
from .yroom_ws import YRoomWebsocket
11 changes: 5 additions & 6 deletions jupyter_rtc_core/websockets/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
"""

from __future__ import annotations
from datetime import timedelta, timezone
import datetime
from datetime import timedelta, timezone, datetime
from logging import Logger
from typing import TYPE_CHECKING
import uuid
Expand Down Expand Up @@ -43,7 +42,7 @@ def synced(self):

@synced.setter
def synced(self, v: bool):
self.synced = v
self._synced = v
self.last_modified = datetime.now(timezone.utc)

class YjsClientGroup:
Expand All @@ -67,7 +66,7 @@ class YjsClientGroup:
"""Log object"""
loop: asyncio.AbstractEventLoop
"""Event loop"""
poll_interval_seconds: int
_poll_interval_seconds: int
"""The poll time interval used while auto removing desynced clients"""
desynced_timeout_seconds: int
"""The max time period in seconds that a desynced client does not become synced before get auto removed from desynced dict"""
Expand All @@ -78,8 +77,8 @@ def __init__(self, *, room_id: str, log: Logger, loop: asyncio.AbstractEventLoop
self.desynced: dict[str, YjsClient] = {}
self.log = log
self.loop = loop
self.loop.create_task(self._clean_desynced())
self.poll_interval_seconds = poll_interval_seconds
# self.loop.create_task(self._clean_desynced())
self._poll_interval_seconds = poll_interval_seconds
self.desynced_timeout_seconds = desynced_timeout_seconds

def add(self, websocket: WebSocketHandler) -> str:
Expand Down
Loading
Loading