Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions jupyter_server_ydoc/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
except ModuleNotFoundError:
raise ModuleNotFoundError("Jupyter Server must be installed to use this extension.")

from traitlets import Float, Int, Type
from ypy_websocket.ystore import BaseYStore # type: ignore
from traitlets import Float, Int, Type, Unicode, observe
from ypy_websocket.ystore import BaseYStore, SQLiteYStore # type: ignore

from .handlers import JupyterSQLiteYStore, YDocRoomIdHandler, YDocWebSocketHandler

Expand Down Expand Up @@ -49,6 +49,39 @@ class YDocExtension(ExtensionApp):
directory, and clears history every 24 hours.""",
)

sqlite_ystore_db_path = Unicode(
".jupyter_ystore.db",
config=True,
help="""The path to the YStore database. Defaults to '.jupyter_ystore.db' in the current
directory. Only applicable if the YStore is an SQLiteYStore.""",
)

@observe("sqlite_ystore_db_path")
def _observe_sqlite_ystore_db_path(self, change):
if issubclass(self.ystore_class, SQLiteYStore):
self.ystore_class.db_path = change["new"]
else:
raise RuntimeError(
f"ystore_class must be an SQLiteYStore to be able to set sqlite_ystore_db_path, not {self.ystore_class}"
)

sqlite_ystore_document_ttl = Int(
None,
allow_none=True,
config=True,
help="""The document time-to-live in seconds. Defaults to None (document history is never
cleared). Only applicable if the YStore is an SQLiteYStore.""",
)

@observe("sqlite_ystore_document_ttl")
def _observe_sqlite_ystore_document_ttl(self, change):
if issubclass(self.ystore_class, SQLiteYStore):
self.ystore_class.document_ttl = change["new"]
else:
raise RuntimeError(
f"ystore_class must be an SQLiteYStore to be able to set sqlite_ystore_document_ttl, not {self.ystore_class}"
)

def initialize_settings(self):
self.settings.update(
{
Expand Down
39 changes: 22 additions & 17 deletions jupyter_server_ydoc/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from jupyter_ydoc import ydocs as YDOCS # type: ignore
from tornado import web
from tornado.websocket import WebSocketHandler
from ypy_websocket import WebsocketServer, YMessageType, YRoom # type: ignore
from ypy_websocket.websocket_server import WebsocketServer, YRoom # type: ignore
from ypy_websocket.ystore import ( # type: ignore
BaseYStore,
SQLiteYStore,
TempFileYStore,
YDocNotFound,
)
from ypy_websocket.yutils import YMessageType # type: ignore

YFILE = YDOCS["file"]

Expand All @@ -30,14 +31,12 @@ class JupyterTempFileYStore(TempFileYStore):

class JupyterSQLiteYStore(SQLiteYStore):
db_path = ".jupyter_ystore.db"
document_ttl = 24 * 60 * 60
document_ttl = None


class DocumentRoom(YRoom):
"""A Y room for a possibly stored document (e.g. a notebook)."""

is_transient = False

def __init__(self, type: str, ystore: BaseYStore, log: Optional[Logger]):
super().__init__(ready=False, ystore=ystore, log=log)
self.type = type
Expand All @@ -49,8 +48,6 @@ def __init__(self, type: str, ystore: BaseYStore, log: Optional[Logger]):
class TransientRoom(YRoom):
"""A Y room for sharing state (e.g. awareness)."""

is_transient = True

def __init__(self, log: Optional[Logger]):
super().__init__(log=log)

Expand Down Expand Up @@ -132,6 +129,7 @@ async def __anext__(self):

def get_file_info(self) -> Tuple[str, str, str]:
assert self.websocket_server is not None
assert isinstance(self.room, DocumentRoom)
room_name = self.websocket_server.get_room_name(self.room)
file_format: str
file_type: str
Expand Down Expand Up @@ -175,10 +173,10 @@ async def open(self, path):
asyncio.create_task(self.websocket_server.serve(self))

# cancel the deletion of the room if it was scheduled
if not self.room.is_transient and self.room.cleaner is not None:
if isinstance(self.room, DocumentRoom) and self.room.cleaner is not None:
self.room.cleaner.cancel()

if not self.room.is_transient and not self.room.ready:
if isinstance(self.room, DocumentRoom) and not self.room.ready:
file_format, file_type, file_path = self.get_file_info()
self.log.debug("Opening Y document from disk: %s", file_path)
model = await ensure_async(
Expand All @@ -188,26 +186,30 @@ async def open(self, path):
# check again if ready, because loading the file can be async
if not self.room.ready:
# try to apply Y updates from the YStore for this document
try:
await self.room.ystore.apply_updates(self.room.ydoc)
read_from_source = False
except YDocNotFound:
# YDoc not found in the YStore, create the document from the source file (no change history)
read_from_source = True
read_from_source = True
if self.room.ystore is not None:
try:
await self.room.ystore.apply_updates(self.room.ydoc)
read_from_source = False
except YDocNotFound:
# YDoc not found in the YStore, create the document from the source file (no change history)
pass
if not read_from_source:
# if YStore updates and source file are out-of-sync, resync updates with source
if self.room.document.source != model["content"]:
read_from_source = True
if read_from_source:
self.room.document.source = model["content"]
await self.room.ystore.encode_state_as_update(self.room.ydoc)
if self.room.ystore:
await self.room.ystore.encode_state_as_update(self.room.ydoc)
self.room.document.dirty = False
self.room.ready = True
self.room.watcher = asyncio.create_task(self.watch_file())
# save the document when changed
self.room.document.observe(self.on_document_change)

async def watch_file(self):
assert isinstance(self.room, DocumentRoom)
poll_interval = self.settings["collaborative_file_poll_interval"]
if not poll_interval:
self.room.watcher = None
Expand All @@ -217,6 +219,7 @@ async def watch_file(self):
await self.maybe_load_document()

async def maybe_load_document(self):
assert isinstance(self.room, DocumentRoom)
file_format, file_type, file_path = self.get_file_info()
async with self.lock:
model = await ensure_async(
Expand Down Expand Up @@ -267,7 +270,7 @@ def on_message(self, message):
# filter out message depending on changes
if skip:
self.log.debug(
"Filtered out Y message of type: %s", YMessageType(message_type).raw_str()
"Filtered out Y message of type: %s", YMessageType(message_type).name
)
return skip
self._message_queue.put_nowait(message)
Expand All @@ -276,12 +279,13 @@ def on_message(self, message):
def on_close(self) -> None:
# stop serving this client
self._message_queue.put_nowait(b"")
if not self.room.is_transient and self.room.clients == [self]:
if isinstance(self.room, DocumentRoom) and self.room.clients == [self]:
# no client in this room after we disconnect
# keep the document for a while in case someone reconnects
self.room.cleaner = asyncio.create_task(self.clean_room())

async def clean_room(self) -> None:
assert isinstance(self.room, DocumentRoom)
seconds = self.settings["collaborative_document_cleanup_delay"]
if seconds is None:
return
Expand Down Expand Up @@ -309,6 +313,7 @@ def on_document_change(self, event):
self.saving_document = asyncio.create_task(self.maybe_save_document())

async def maybe_save_document(self):
assert isinstance(self.room, DocumentRoom)
seconds = self.settings["collaborative_document_save_delay"]
if seconds is None:
return
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
requires-python = ">=3.7"
dependencies = [
"jupyter_ydoc>=0.2.0,<0.4.0",
"ypy-websocket>=0.8.0,<0.9.0",
"ypy-websocket>=0.8.1,<0.9.0",
"jupyter_server_fileid >=0.6.0,<1"
]

Expand Down