44import uuid
55from typing import Any , Dict , List , Optional , cast
66
7+ from pycrdt import Array , Map
8+
9+ from jupyverse_api .yjs import Yjs
10+
711from .connect import cfg_t , connect_channel , launch_kernel , read_connection_file
812from .connect import write_connection_file as _write_connection_file
913from .kernelspec import find_kernelspec
@@ -23,10 +27,12 @@ def __init__(
2327 connection_file : str = "" ,
2428 write_connection_file : bool = True ,
2529 capture_kernel_output : bool = True ,
30+ yjs : Optional [Yjs ] = None ,
2631 ) -> None :
2732 self .capture_kernel_output = capture_kernel_output
2833 self .kernelspec_path = kernelspec_path or find_kernelspec (kernel_name )
2934 self .kernel_cwd = kernel_cwd
35+ self .yjs = yjs
3036 if not self .kernelspec_path :
3137 raise RuntimeError ("Could not find a kernel, maybe you forgot to install one?" )
3238 if write_connection_file :
@@ -37,11 +43,12 @@ def __init__(
3743 self .key = cast (str , self .connection_cfg ["key" ])
3844 self .session_id = uuid .uuid4 ().hex
3945 self .msg_cnt = 0
40- self .execute_requests : Dict [str , Dict [str , asyncio .Future ]] = {}
41- self .channel_tasks : List [asyncio .Task ] = []
46+ self .execute_requests : Dict [str , Dict [str , asyncio .Queue ]] = {}
47+ self .comm_messages : asyncio .Queue = asyncio .Queue ()
48+ self .tasks : List [asyncio .Task ] = []
4249
4350 async def restart (self , startup_timeout : float = float ("inf" )) -> None :
44- for task in self .channel_tasks :
51+ for task in self .tasks :
4552 task .cancel ()
4653 msg = create_message ("shutdown_request" , content = {"restart" : True })
4754 await send_message (msg , self .control_channel , self .key , change_date_to_str = True )
@@ -52,7 +59,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
5259 if msg ["msg_type" ] == "shutdown_reply" and msg ["content" ]["restart" ]:
5360 break
5461 await self ._wait_for_ready (startup_timeout )
55- self .channel_tasks = []
62+ self .tasks = []
5663 self .listen_channels ()
5764
5865 async def start (self , startup_timeout : float = float ("inf" ), connect : bool = True ) -> None :
@@ -69,6 +76,7 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
6976 self .connect_channels ()
7077 await self ._wait_for_ready (startup_timeout )
7178 self .listen_channels ()
79+ self .tasks .append (asyncio .create_task (self ._handle_comms ()))
7280
7381 def connect_channels (self , connection_cfg : Optional [cfg_t ] = None ):
7482 connection_cfg = connection_cfg or self .connection_cfg
@@ -77,40 +85,43 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
7785 self .iopub_channel = connect_channel ("iopub" , connection_cfg )
7886
7987 def listen_channels (self ):
80- self .channel_tasks .append (asyncio .create_task (self .listen_iopub ()))
81- self .channel_tasks .append (asyncio .create_task (self .listen_shell ()))
88+ self .tasks .append (asyncio .create_task (self .listen_iopub ()))
89+ self .tasks .append (asyncio .create_task (self .listen_shell ()))
8290
8391 async def stop (self ) -> None :
8492 self .kernel_process .kill ()
8593 await self .kernel_process .wait ()
8694 os .remove (self .connection_file_path )
87- for task in self .channel_tasks :
95+ for task in self .tasks :
8896 task .cancel ()
8997
9098 async def listen_iopub (self ):
9199 while True :
92100 msg = await receive_message (self .iopub_channel , change_str_to_date = True )
93- msg_id = msg ["parent_header" ].get ("msg_id" )
94- if msg_id in self .execute_requests .keys ():
95- self .execute_requests [msg_id ]["iopub_msg" ].set_result (msg )
101+ parent_id = msg ["parent_header" ].get ("msg_id" )
102+ if msg ["msg_type" ] in ("comm_open" , "comm_msg" ):
103+ self .comm_messages .put_nowait (msg )
104+ elif parent_id in self .execute_requests .keys ():
105+ self .execute_requests [parent_id ]["iopub_msg" ].put_nowait (msg )
96106
97107 async def listen_shell (self ):
98108 while True :
99109 msg = await receive_message (self .shell_channel , change_str_to_date = True )
100110 msg_id = msg ["parent_header" ].get ("msg_id" )
101111 if msg_id in self .execute_requests .keys ():
102- self .execute_requests [msg_id ]["shell_msg" ].set_result (msg )
112+ self .execute_requests [msg_id ]["shell_msg" ].put_nowait (msg )
103113
104114 async def execute (
105115 self ,
106- cell : Dict [ str , Any ] ,
116+ ycell : Map ,
107117 timeout : float = float ("inf" ),
108118 msg_id : str = "" ,
109119 wait_for_executed : bool = True ,
110120 ) -> None :
111- if cell ["cell_type" ] != "code" :
121+ if ycell ["cell_type" ] != "code" :
112122 return
113- content = {"code" : cell ["source" ], "silent" : False }
123+ ycell ["execution_state" ] = "busy"
124+ content = {"code" : str (ycell ["source" ]), "silent" : False }
114125 msg = create_message (
115126 "execute_request" , content , session_id = self .session_id , msg_id = str (self .msg_cnt )
116127 )
@@ -120,40 +131,68 @@ async def execute(
120131 msg_id = msg ["header" ]["msg_id" ]
121132 self .msg_cnt += 1
122133 await send_message (msg , self .shell_channel , self .key , change_date_to_str = True )
134+ self .execute_requests [msg_id ] = {
135+ "iopub_msg" : asyncio .Queue (),
136+ "shell_msg" : asyncio .Queue (),
137+ }
123138 if wait_for_executed :
124139 deadline = time .time () + timeout
125- self .execute_requests [msg_id ] = {
126- "iopub_msg" : asyncio .Future (),
127- "shell_msg" : asyncio .Future (),
128- }
129140 while True :
130141 try :
131- await asyncio .wait_for (
132- self .execute_requests [msg_id ]["iopub_msg" ],
142+ msg = await asyncio .wait_for (
143+ self .execute_requests [msg_id ]["iopub_msg" ]. get () ,
133144 deadline_to_timeout (deadline ),
134145 )
135146 except asyncio .TimeoutError :
136147 error_message = f"Kernel didn't respond in { timeout } seconds"
137148 raise RuntimeError (error_message )
138- msg = self .execute_requests [msg_id ]["iopub_msg" ].result ()
139- self ._handle_outputs (cell ["outputs" ], msg )
149+ await self ._handle_outputs (ycell ["outputs" ], msg )
140150 if (
141- msg ["header" ]["msg_type" ] == "status"
142- and msg ["content" ]["execution_state" ] == "idle"
151+ ( msg ["header" ]["msg_type" ] == "status"
152+ and msg ["content" ]["execution_state" ] == "idle" )
143153 ):
144154 break
145- self .execute_requests [msg_id ]["iopub_msg" ] = asyncio .Future ()
146155 try :
147- await asyncio .wait_for (
148- self .execute_requests [msg_id ]["shell_msg" ],
156+ msg = await asyncio .wait_for (
157+ self .execute_requests [msg_id ]["shell_msg" ]. get () ,
149158 deadline_to_timeout (deadline ),
150159 )
151160 except asyncio .TimeoutError :
152161 error_message = f"Kernel didn't respond in { timeout } seconds"
153162 raise RuntimeError (error_message )
154- msg = self .execute_requests [msg_id ]["shell_msg" ].result ()
155- cell ["execution_count" ] = msg ["content" ]["execution_count" ]
163+ with ycell .doc .transaction ():
164+ ycell ["execution_count" ] = msg ["content" ]["execution_count" ]
165+ ycell ["execution_state" ] = "idle"
156166 del self .execute_requests [msg_id ]
167+ else :
168+ self .tasks .append (asyncio .create_task (self ._handle_iopub (msg_id , ycell )))
169+
170+ async def _handle_iopub (self , msg_id : str , ycell : Map ) -> None :
171+ while True :
172+ msg = await self .execute_requests [msg_id ]["iopub_msg" ].get ()
173+ await self ._handle_outputs (ycell ["outputs" ], msg )
174+ if (
175+ (msg ["header" ]["msg_type" ] == "status"
176+ and msg ["content" ]["execution_state" ] == "idle" )
177+ ):
178+ msg = await self .execute_requests [msg_id ]["shell_msg" ].get ()
179+ with ycell .doc .transaction ():
180+ ycell ["execution_count" ] = msg ["content" ]["execution_count" ]
181+ ycell ["execution_state" ] = "idle"
182+
183+ async def _handle_comms (self ) -> None :
184+ if self .yjs is None :
185+ return
186+
187+ while True :
188+ msg = await self .comm_messages .get ()
189+ msg_type = msg ["header" ]["msg_type" ]
190+ if msg_type == "comm_open" :
191+ comm_id = msg ["content" ]["comm_id" ]
192+ comm = Comm (comm_id , self .shell_channel , self .session_id , self .key )
193+ self .yjs .widgets .comm_open (msg , comm ) # type: ignore
194+ elif msg_type == "comm_msg" :
195+ self .yjs .widgets .comm_msg (msg ) # type: ignore
157196
158197 async def _wait_for_ready (self , timeout ):
159198 deadline = time .time () + timeout
@@ -178,22 +217,51 @@ async def _wait_for_ready(self, timeout):
178217 break
179218 new_timeout = deadline_to_timeout (deadline )
180219
181- def _handle_outputs (self , outputs : List [ Dict [ str , Any ]] , msg : Dict [str , Any ]):
220+ async def _handle_outputs (self , outputs : Array , msg : Dict [str , Any ]):
182221 msg_type = msg ["header" ]["msg_type" ]
183222 content = msg ["content" ]
184223 if msg_type == "stream" :
185- if (not outputs ) or (outputs [- 1 ]["name" ] != content ["name" ]):
186- outputs .append ({"name" : content ["name" ], "output_type" : msg_type , "text" : []})
187- outputs [- 1 ]["text" ].append (content ["text" ])
224+ with outputs .doc .transaction ():
225+ # TODO: uncomment when changes are made in jupyter-ydoc
226+ if (not outputs ) or (outputs [- 1 ]["name" ] != content ["name" ]): # type: ignore
227+ outputs .append (
228+ #Map(
229+ # {
230+ # "name": content["name"],
231+ # "output_type": msg_type,
232+ # "text": Array([content["text"]]),
233+ # }
234+ #)
235+ {
236+ "name" : content ["name" ],
237+ "output_type" : msg_type ,
238+ "text" : [content ["text" ]],
239+ }
240+ )
241+ else :
242+ #outputs[-1]["text"].append(content["text"]) # type: ignore
243+ last_output = outputs [- 1 ]
244+ last_output ["text" ].append (content ["text" ]) # type: ignore
245+ outputs [- 1 ] = last_output
188246 elif msg_type in ("display_data" , "execute_result" ):
189- outputs .append (
190- {
191- "data" : {"text/plain" : [content ["data" ].get ("text/plain" , "" )]},
192- "execution_count" : content ["execution_count" ],
193- "metadata" : {},
194- "output_type" : msg_type ,
195- }
196- )
247+ if "application/vnd.jupyter.ywidget-view+json" in content ["data" ]:
248+ # this is a collaborative widget
249+ model_id = content ["data" ]["application/vnd.jupyter.ywidget-view+json" ]["model_id" ]
250+ if self .yjs is not None :
251+ if model_id in self .yjs .widgets .widgets : # type: ignore
252+ doc = self .yjs .widgets .widgets [model_id ]["model" ].ydoc # type: ignore
253+ path = f"ywidget:{ doc .guid } "
254+ await self .yjs .room_manager .websocket_server .get_room (path , ydoc = doc ) # type: ignore
255+ outputs .append (doc )
256+ else :
257+ outputs .append (
258+ {
259+ "data" : {"text/plain" : [content ["data" ].get ("text/plain" , "" )]},
260+ "execution_count" : content ["execution_count" ],
261+ "metadata" : {},
262+ "output_type" : msg_type ,
263+ }
264+ )
197265 elif msg_type == "error" :
198266 outputs .append (
199267 {
@@ -203,5 +271,25 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
203271 "traceback" : content ["traceback" ],
204272 }
205273 )
206- else :
207- return
274+
275+
276+ class Comm :
277+ def __init__ (self , comm_id : str , shell_channel , session_id : str , key : str ):
278+ self .comm_id = comm_id
279+ self .shell_channel = shell_channel
280+ self .session_id = session_id
281+ self .key = key
282+ self .msg_cnt = 0
283+
284+ def send (self , buffers ):
285+ msg = create_message (
286+ "comm_msg" ,
287+ content = {"comm_id" : self .comm_id },
288+ session_id = self .session_id ,
289+ msg_id = self .msg_cnt ,
290+ buffers = buffers ,
291+ )
292+ self .msg_cnt += 1
293+ asyncio .create_task (
294+ send_message (msg , self .shell_channel , self .key , change_date_to_str = True )
295+ )
0 commit comments