Skip to content

Modify API to optionally create a document room if it does not exist. #489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions projects/jupyter-server-ydoc/jupyter_server_ydoc/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncio
from functools import partial
from typing import Literal
from typing import Literal, cast

from jupyter_server.extension.application import ExtensionApp
from jupyter_ydoc import ydocs as YDOCS
Expand All @@ -29,6 +29,7 @@
FORK_EVENTS_SCHEMA_PATH,
encode_file_path,
room_id_from_encoded_path,
decode_file_path,
)
from .websocketserver import JupyterWebsocketServer, RoomNotFound, exception_logger

Expand Down Expand Up @@ -83,6 +84,8 @@ class YDocExtension(ExtensionApp):
model.""",
)

_room_locks: dict[str, asyncio.Lock] = {}

def initialize(self):
super().initialize()
self.serverapp.event_logger.register_event_schema(EVENTS_SCHEMA_PATH)
Expand Down Expand Up @@ -142,6 +145,7 @@ def initialize_handlers(self):
"file_loaders": self.file_loaders,
"ystore_class": ystore_class,
"ywebsocket_server": self.ywebsocket_server,
"_room_locks": self._room_locks,
},
),
(r"/api/collaboration/session/(.*)", DocSessionHandler),
Expand All @@ -163,6 +167,11 @@ def initialize_handlers(self):
]
)

def _room_lock(self, room_id: str) -> asyncio.Lock:
if room_id not in self._room_locks:
self._room_locks[room_id] = asyncio.Lock()
return self._room_locks[room_id]

async def get_document(
self: YDocExtension,
*,
Expand All @@ -171,6 +180,7 @@ async def get_document(
file_format: Literal["json", "text"] | None = None,
room_id: str | None = None,
copy: bool = True,
create: bool = False,
) -> YBaseDoc | None:
"""Get a view of the shared model for the matching document.

Expand All @@ -179,6 +189,8 @@ async def get_document(

If `copy=True`, the returned shared model is a fork, meaning that any changes
made to it will not be propagated to the shared model used by the application.

If `create=True`, the room will be created if it doesn't exist.
"""
error_msg = "You need to provide either a ``room_id`` or the ``path``, the ``content_type`` and the ``file_format``."
if room_id is None:
Expand All @@ -193,13 +205,54 @@ async def get_document(

elif path is not None or content_type is not None or file_format is not None:
raise ValueError(error_msg)
else:
room_id = room_id

try:
room = await self.ywebsocket_server.get_room(room_id)
except RoomNotFound:
return None
if not self.ywebsocket_server.started.is_set():
asyncio.create_task(self.ywebsocket_server.start())
await self.ywebsocket_server.started.wait()
async with self._room_lock(room_id):
try:
room = await self.ywebsocket_server.get_room(room_id)
except RoomNotFound:
if not create:
return None

file_format_str, file_type, file_id = decode_file_path(room_id)
# cast down so mypy won’t complain when we pass this into DocumentRoom
file_format = cast(Literal["json", "text"], file_format_str)
updates_file_path = f".{file_type}:{file_id}.y"
ystore = self.ystore_class(
path=updates_file_path,
log=self.log,
)
# Create a new room
room = DocumentRoom(
room_id,
file_format,
file_type,
self.file_loaders[file_id],
self.serverapp.event_logger,
ystore,
self.log,
exception_handler=exception_logger,
save_delay=self.document_save_delay,
)
await room.initialize()
try:
await self.ywebsocket_server.start_room(room)
self.ywebsocket_server.add_room(room_id, room)
self.log.info(f"Created and started room: {room_id}")
except Exception as e:
self.log.error("Room %s failed to start on websocket server", room_id)
# Clean room
await room.stop()
self.log.info("Room %s deleted", room_id)
file = self.file_loaders[file_id]
if file.number_of_subscriptions == 0 or (
file.number_of_subscriptions == 1 and room_id in file._subscriptions
):
self.log.info("Deleting file %s", file.path)
await self.file_loaders.remove(file_id)
raise e

if isinstance(room, DocumentRoom):
if copy:
Expand Down
3 changes: 2 additions & 1 deletion projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class YDocWebSocketHandler(WebSocketHandler, JupyterHandler):

_message_queue: asyncio.Queue[Any]
_background_tasks: set[asyncio.Task]
_room_locks: dict[str, asyncio.Lock] = {}

def _room_lock(self, room_id: str) -> asyncio.Lock:
if room_id not in self._room_locks:
Expand Down Expand Up @@ -173,6 +172,7 @@ def initialize(
ywebsocket_server: JupyterWebsocketServer,
file_loaders: FileLoaderMapping,
ystore_class: type[BaseYStore],
_room_locks: dict[str, asyncio.Lock],
document_cleanup_delay: float | None = 60.0,
document_save_delay: float | None = 1.0,
) -> None:
Expand All @@ -187,6 +187,7 @@ def initialize(
self._message_queue = asyncio.Queue()
self._room_id = ""
self.room = None # type:ignore
self._room_locks = _room_locks

@property
def path(self):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,24 @@ async def test_get_document_file(rtc_create_file, jp_serverapp, copy):
await collaboration.stop_extension()


async def test_get_document_create_room(rtc_create_file, jp_serverapp):
path, content = await rtc_create_file("test.txt", "test", index=True)
collaboration = jp_serverapp.web_app.settings["jupyter_server_ydoc"]
# Document doesn't exist initially
document_before = await collaboration.get_document(
path=path, content_type="file", file_format="text", create=False
)
assert document_before is None

document_after = await collaboration.get_document(
path=path, content_type="file", file_format="text", create=True
)
# Verify document was created and has correct content
assert document_after is not None
assert document_after.get() == content == "test"
await collaboration.stop_extension()


@pytest.mark.parametrize("copy", [True, False])
async def test_get_document_notebook(rtc_create_notebook, jp_serverapp, copy):
nb = nbformat.v4.new_notebook(
Expand Down
Loading