@@ -54,28 +54,23 @@ class YDocWebSocketHandler(WebSocketHandler, JupyterHandler):
54
54
"""
55
55
56
56
_message_queue : asyncio .Queue [Any ]
57
+ _background_tasks : set [asyncio .Task ]
57
58
58
- def initialize (
59
- self ,
60
- ywebsocket_server : JupyterWebsocketServer ,
61
- file_loaders : FileLoaderMapping ,
62
- ystore_class : type [BaseYStore ],
63
- document_cleanup_delay : float | None = 60.0 ,
64
- document_save_delay : float | None = 1.0 ,
65
- ) -> None :
66
- # File ID manager cannot be passed as argument as the extension may load after this one
67
- self ._file_id_manager = self .settings ["file_id_manager" ]
68
- self ._file_loaders = file_loaders
69
- self ._cleanup_delay = document_cleanup_delay
70
- self ._websocket_server = ywebsocket_server
59
+ def create_task (self , aw ):
60
+ task = asyncio .create_task (aw )
61
+ self ._background_tasks .add (task )
62
+ task .add_done_callback (self ._background_tasks .discard )
71
63
72
- self ._message_queue = asyncio .Queue ()
64
+ async def prepare (self ):
65
+ if not self ._websocket_server .started .is_set ():
66
+ self .create_task (self ._websocket_server .start ())
67
+ await self ._websocket_server .started .wait ()
73
68
74
69
# Get room
75
70
self ._room_id : str = self .request .path .split ("/" )[- 1 ]
76
71
77
72
if self ._websocket_server .room_exists (self ._room_id ):
78
- self .room : YRoom = self ._websocket_server .get_room (self ._room_id )
73
+ self .room : YRoom = await self ._websocket_server .get_room (self ._room_id )
79
74
80
75
else :
81
76
if self ._room_id .count (":" ) >= 2 :
@@ -92,7 +87,7 @@ def initialize(
92
87
path = self ._file_id_manager .get_path (file_id )
93
88
path = Path (path )
94
89
updates_file_path = str (path .parent / f".{ file_type } :{ path .name } .y" )
95
- ystore = ystore_class (path = updates_file_path , log = self .log )
90
+ ystore = self . _ystore_class (path = updates_file_path , log = self .log )
96
91
self .room = DocumentRoom (
97
92
self ._room_id ,
98
93
file_format ,
@@ -101,16 +96,37 @@ def initialize(
101
96
self .event_logger ,
102
97
ystore ,
103
98
self .log ,
104
- document_save_delay ,
99
+ self . _document_save_delay ,
105
100
)
106
101
107
102
else :
108
103
# TransientRoom
109
104
# it is a transient document (e.g. awareness)
110
105
self .room = TransientRoom (self ._room_id , self .log )
111
106
107
+ await self ._websocket_server .start_room (self .room )
112
108
self ._websocket_server .add_room (self ._room_id , self .room )
113
109
110
+ return await super ().prepare ()
111
+
112
+ def initialize (
113
+ self ,
114
+ ywebsocket_server : JupyterWebsocketServer ,
115
+ file_loaders : FileLoaderMapping ,
116
+ ystore_class : type [BaseYStore ],
117
+ document_cleanup_delay : float | None = 60.0 ,
118
+ document_save_delay : float | None = 1.0 ,
119
+ ) -> None :
120
+ self ._background_tasks = set ()
121
+ # File ID manager cannot be passed as argument as the extension may load after this one
122
+ self ._file_id_manager = self .settings ["file_id_manager" ]
123
+ self ._file_loaders = file_loaders
124
+ self ._ystore_class = ystore_class
125
+ self ._cleanup_delay = document_cleanup_delay
126
+ self ._document_save_delay = document_save_delay
127
+ self ._websocket_server = ywebsocket_server
128
+ self ._message_queue = asyncio .Queue ()
129
+
114
130
@property
115
131
def path (self ):
116
132
"""
@@ -150,9 +166,7 @@ async def open(self, room_id):
150
166
"""
151
167
On connection open.
152
168
"""
153
- task = asyncio .create_task (self ._websocket_server .serve (self ))
154
- self ._websocket_server .background_tasks .add (task )
155
- task .add_done_callback (self ._websocket_server .background_tasks .discard )
169
+ self .create_task (self ._websocket_server .serve (self ))
156
170
157
171
if isinstance (self .room , DocumentRoom ):
158
172
# Close the connection if the document session expired
0 commit comments