11from __future__ import annotations # see PEP-563 for motivation behind this
2- from typing import TYPE_CHECKING
2+ from typing import TYPE_CHECKING , cast
33from logging import Logger
44import asyncio
55from ..websockets import YjsClientGroup
66
77import pycrdt
88from pycrdt import YMessageType , YSyncMessageType as YSyncMessageSubtype
9+ from jupyter_ydoc import ydocs as jupyter_ydoc_classes
10+ from jupyter_ydoc .ybasedoc import YBaseDoc
911from tornado .websocket import WebSocketHandler
12+ from .yroom_file_api import YRoomFileAPI
1013
1114if TYPE_CHECKING :
1215 from typing import Literal , Tuple
16+ from jupyter_server_fileid .manager import BaseFileIdManager
17+ from jupyter_server .services .contents .manager import AsyncContentsManager , ContentsManager
1318
1419class YRoom :
1520 """A Room to manage all client connection to one notebook file"""
21+
22+ log : Logger
23+ """Log object"""
1624 room_id : str
1725 """Room Id"""
18- ydoc : pycrdt .Doc
26+ _jupyter_ydoc : YBaseDoc
27+ """JupyterYDoc"""
28+ _ydoc : pycrdt .Doc
1929 """Ydoc"""
20- awareness : pycrdt .Awareness
30+ _awareness : pycrdt .Awareness
2131 """Ydoc awareness object"""
22- loop : asyncio .AbstractEventLoop
32+ _loop : asyncio .AbstractEventLoop
2333 """Event loop"""
24- log : Logger
25- """Log object"""
2634 _client_group : YjsClientGroup
2735 """Client group to manage synced and desynced clients"""
2836 _message_queue : asyncio .Queue [Tuple [str , bytes ]]
2937 """A message queue per room to keep websocket messages in order"""
3038
3139
32- def __init__ (self , * , room_id : str , log : Logger , loop : asyncio .AbstractEventLoop ):
40+ def __init__ (
41+ self ,
42+ * ,
43+ room_id : str ,
44+ log : Logger ,
45+ loop : asyncio .AbstractEventLoop ,
46+ fileid_manager : BaseFileIdManager ,
47+ contents_manager : AsyncContentsManager | ContentsManager ,
48+ ):
3349 # Bind instance attributes
3450 self .log = log
35- self .loop = loop
51+ self ._loop = loop
3652 self .room_id = room_id
3753
38- # Initialize YDoc, YAwareness, YjsClientGroup, and message queue
39- self .ydoc = pycrdt .Doc ()
40- self .awareness = pycrdt .Awareness (ydoc = self .ydoc )
41- self .awareness .observe (self .send_server_awareness )
42- self ._client_group = YjsClientGroup (room_id = room_id , log = self .log , loop = self .loop )
43- self ._message_queue = asyncio .Queue ()
54+ # Initialize YjsClientGroup, YDoc, YAwareness, JupyterYDoc
55+ self ._client_group = YjsClientGroup (room_id = room_id , log = self .log , loop = self ._loop )
56+ self ._ydoc = pycrdt .Doc ()
57+ self ._awareness = pycrdt .Awareness (ydoc = self ._ydoc )
58+ JupyterYDocClass = cast (
59+ type [YBaseDoc ],
60+ jupyter_ydoc_classes .get (self .file_type , jupyter_ydoc_classes ["file" ])
61+ )
62+ self .jupyter_ydoc = JupyterYDocClass (ydoc = self ._ydoc , awareness = self ._awareness )
63+
64+ # Initialize YRoomFileAPI and begin loading content
65+ self .file_api = YRoomFileAPI (
66+ room_id = self .room_id ,
67+ jupyter_ydoc = self .jupyter_ydoc ,
68+ log = self .log ,
69+ loop = self ._loop ,
70+ fileid_manager = fileid_manager ,
71+ contents_manager = contents_manager
72+ )
73+ self .file_api .load_ydoc_content ()
4474
45- # Start observer on the `ydoc` to ensure new updates are broadcast to
46- # all clients and saved to disk.
47- self .ydoc .observe (lambda event : self .write_sync_update (event .update ))
75+ # Start observers on `self.ydoc` and `self.awareness` to ensure new
76+ # updates are broadcast to all clients and saved to disk.
77+ self ._awareness .observe (self .send_server_awareness )
78+ self ._ydoc .observe (lambda event : self .write_sync_update (event .update ))
4879
49- # Start background task that routes new messages in the message queue
50- # to the appropriate handler method.
51- self .loop .create_task (self ._on_new_message ())
80+ # Initialize message queue and start background task that routes new
81+ # messages in the message queue to the appropriate handler method.
82+ self ._message_queue = asyncio .Queue ()
83+ self ._loop .create_task (self ._on_new_message ())
5284
5385
5486 @property
@@ -59,21 +91,32 @@ def clients(self) -> YjsClientGroup:
5991 """
6092
6193 return self ._client_group
62-
6394
64- def add_client (self , websocket : WebSocketHandler ) -> str :
95+
96+ async def get_jupyter_ydoc (self ):
6597 """
66- Creates a new client from the given Tornado WebSocketHandler and
67- adds it to the room. Returns the ID of the new client.
98+ Returns a reference to the room's JupyterYDoc
99+ (`jupyter_ydoc.ybasedoc.YBaseDoc`) after waiting for its content to be
100+ loaded from the ContentsManager.
68101 """
102+ await self .file_api .ydoc_content_loaded
103+ return self .jupyter_ydoc
104+
69105
70- return self .clients .add (websocket )
71-
72-
73- def remove_client (self , client_id : str ) -> None :
74- """Removes a client from the room, given the client ID."""
106+ async def get_ydoc (self ):
107+ """
108+ Returns a reference to the room's YDoc (`pycrdt.Doc`) after
109+ waiting for its content to be loaded from the ContentsManager.
110+ """
111+ await self .file_api .ydoc_content_loaded
112+ return self ._ydoc
75113
76- self .clients .remove (client_id )
114+
115+ def get_awareness (self ):
116+ """
117+ Returns a reference to the room's awareness (`pycrdt.Awareness`).
118+ """
119+ return self ._awareness
77120
78121
79122 def add_message (self , client_id : str , message : bytes ) -> None :
@@ -91,6 +134,11 @@ async def _on_new_message(self) -> None:
91134 message type & subtype, which are obtained from the first 2 bytes of the
92135 message.
93136 """
137+ # Wait for content to be loaded before processing any messages in the
138+ # message queue
139+ await self .file_api .ydoc_content_loaded
140+
141+ # Begin processing messages from the message queue
94142 while True :
95143 try :
96144 client_id , message = await self ._message_queue .get ()
@@ -141,7 +189,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
141189 # Compute SyncStep2 reply
142190 try :
143191 message_payload = message [1 :]
144- sync_step2_message = pycrdt .handle_sync_message (message_payload , self .ydoc )
192+ sync_step2_message = pycrdt .handle_sync_message (message_payload , self ._ydoc )
145193 assert isinstance (sync_step2_message , bytes )
146194 except Exception as e :
147195 self .log .error (
@@ -170,7 +218,7 @@ def handle_sync_step1(self, client_id: str, message: bytes) -> None:
170218 # Send SyncStep1 message
171219 try :
172220 assert isinstance (new_client .websocket , WebSocketHandler )
173- sync_step1_message = pycrdt .create_sync_message (self .ydoc )
221+ sync_step1_message = pycrdt .create_sync_message (self ._ydoc )
174222 new_client .websocket .write_message (sync_step1_message )
175223 except Exception as e :
176224 self .log .error (
@@ -191,7 +239,7 @@ def handle_sync_step2(self, client_id: str, message: bytes) -> None:
191239 """
192240 try :
193241 message_payload = message [1 :]
194- pycrdt .handle_sync_message (message_payload , self .ydoc )
242+ pycrdt .handle_sync_message (message_payload , self ._ydoc )
195243 except Exception as e :
196244 self .log .error (
197245 "An exception occurred when applying a SyncStep2 message "
@@ -217,7 +265,7 @@ def handle_sync_update(self, client_id: str, message: bytes) -> None:
217265 # Apply the SyncUpdate to the YDoc
218266 try :
219267 message_payload = message [1 :]
220- pycrdt .handle_sync_message (message_payload , self .ydoc )
268+ pycrdt .handle_sync_message (message_payload , self ._ydoc )
221269 except Exception as e :
222270 self .log .error (
223271 "An exception occurred when applying a SyncUpdate message "
@@ -251,7 +299,7 @@ def handle_awareness_update(self, client_id: str, message: bytes) -> None:
251299 # Apply the AwarenessUpdate message
252300 try :
253301 message_payload = message [1 :]
254- self .awareness .apply_awareness_update (message_payload , origin = self )
302+ self ._awareness .apply_awareness_update (message_payload , origin = self )
255303 except Exception as e :
256304 self .log .error (
257305 "An exception occurred when applying an AwarenessUpdate"
@@ -315,7 +363,7 @@ def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any])
315363 return
316364
317365 updated_clients = [v for value in changes [0 ].values () for v in value ]
318- state = self .awareness .encode_awareness_update (updated_clients )
366+ state = self ._awareness .encode_awareness_update (updated_clients )
319367 message = pycrdt .create_awareness_message (state )
320368 self ._broadcast_message (message , "AwarenessUpdate" )
321369
0 commit comments