diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/app.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/app.py index b1d73d59..e981f13e 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/app.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/app.py @@ -4,7 +4,8 @@ import asyncio from functools import partial -from typing import Literal +from typing import Literal, cast +from collections import defaultdict from jupyter_server.extension.application import ExtensionApp from jupyter_ydoc import ydocs as YDOCS @@ -29,6 +30,7 @@ FORK_EVENTS_SCHEMA_PATH, encode_file_path, room_id_from_encoded_path, + decode_file_path, ) from .websocketserver import JupyterWebsocketServer, RoomNotFound, exception_logger @@ -83,6 +85,8 @@ class YDocExtension(ExtensionApp): model.""", ) + _room_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + def initialize(self): super().initialize() self.serverapp.event_logger.register_event_schema(EVENTS_SCHEMA_PATH) @@ -142,6 +146,7 @@ def initialize_handlers(self): "file_loaders": self.file_loaders, "ystore_class": ystore_class, "ywebsocket_server": self.ywebsocket_server, + "room_locks": self._room_locks, }, ), (r"/api/collaboration/session/(.*)", DocSessionHandler), @@ -171,6 +176,7 @@ async def get_document( file_format: Literal["json", "text"] | None = None, room_id: str | None = None, copy: bool = True, + create: bool = False, ) -> YBaseDoc | None: """Get a view of the shared model for the matching document. @@ -179,6 +185,8 @@ async def get_document( If `copy=True`, the returned shared model is a fork, meaning that any changes made to it will not be propagated to the shared model used by the application. + + If `create=True`, the room will be created if it doesn't exist. """ error_msg = "You need to provide either a ``room_id`` or the ``path``, the ``content_type`` and the ``file_format``." if room_id is None: @@ -193,13 +201,54 @@ async def get_document( elif path is not None or content_type is not None or file_format is not None: raise ValueError(error_msg) - else: - room_id = room_id - try: - room = await self.ywebsocket_server.get_room(room_id) - except RoomNotFound: - return None + if not self.ywebsocket_server.started.is_set(): + asyncio.create_task(self.ywebsocket_server.start()) + await self.ywebsocket_server.started.wait() + async with self._room_locks[room_id]: + try: + room = await self.ywebsocket_server.get_room(room_id) + except RoomNotFound: + if not create: + return None + + file_format_str, file_type, file_id = decode_file_path(room_id) + # cast down so mypy won’t complain when we pass this into DocumentRoom + file_format = cast(Literal["json", "text"], file_format_str) + updates_file_path = f".{file_type}:{file_id}.y" + ystore = self.ystore_class( + path=updates_file_path, + log=self.log, + ) + # Create a new room + room = DocumentRoom( + room_id, + file_format, + file_type, + self.file_loaders[file_id], + self.serverapp.event_logger, + ystore, + self.log, + exception_handler=exception_logger, + save_delay=self.document_save_delay, + ) + await room.initialize() + try: + await self.ywebsocket_server.start_room(room) + self.ywebsocket_server.add_room(room_id, room) + self.log.info(f"Created and started room: {room_id}") + except Exception as e: + self.log.error("Room %s failed to start on websocket server", room_id) + # Clean room + await room.stop() + self.log.info("Room %s deleted", room_id) + file = self.file_loaders[file_id] + if file.number_of_subscriptions == 0 or ( + file.number_of_subscriptions == 1 and room_id in file._subscriptions + ): + self.log.info("Deleting file %s", file.path) + await self.file_loaders.remove(file_id) + raise e if isinstance(room, DocumentRoom): if copy: diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py index 940370a0..ceef2aa1 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py @@ -66,7 +66,6 @@ class YDocWebSocketHandler(WebSocketHandler, JupyterHandler): _message_queue: asyncio.Queue[Any] _background_tasks: set[asyncio.Task] - _room_locks: dict[str, asyncio.Lock] = {} def _room_lock(self, room_id: str) -> asyncio.Lock: if room_id not in self._room_locks: @@ -173,6 +172,7 @@ def initialize( ywebsocket_server: JupyterWebsocketServer, file_loaders: FileLoaderMapping, ystore_class: type[BaseYStore], + room_locks: dict[str, asyncio.Lock] | None = None, document_cleanup_delay: float | None = 60.0, document_save_delay: float | None = 1.0, ) -> None: @@ -187,6 +187,7 @@ def initialize( self._message_queue = asyncio.Queue() self._room_id = "" self.room = None # type:ignore + self._room_locks = room_locks if room_locks is not None else {} @property def path(self): diff --git a/tests/test_app.py b/tests/test_app.py index b46b2bb7..235630a7 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -88,6 +88,24 @@ async def test_get_document_file(rtc_create_file, jp_serverapp, copy): await collaboration.stop_extension() +async def test_get_document_create_room(rtc_create_file, jp_serverapp): + path, content = await rtc_create_file("test.txt", "test", index=True) + collaboration = jp_serverapp.web_app.settings["jupyter_server_ydoc"] + # Document doesn't exist initially + document_before = await collaboration.get_document( + path=path, content_type="file", file_format="text", create=False + ) + assert document_before is None + + document_after = await collaboration.get_document( + path=path, content_type="file", file_format="text", create=True + ) + # Verify document was created and has correct content + assert document_after is not None + assert document_after.get() == content == "test" + await collaboration.stop_extension() + + @pytest.mark.parametrize("copy", [True, False]) async def test_get_document_notebook(rtc_create_notebook, jp_serverapp, copy): nb = nbformat.v4.new_notebook(