1+ """
2+ A new Kernel client that is aware of ydocuments.
3+ """
14import asyncio
25import json
36import typing as t
710from .utils import LRUCache
811from jupyter_client .asynchronous .client import AsyncKernelClient
912import anyio
13+ from jupyter_rtc_core .rooms .yroom import YRoom
1014
1115
12- class NextGenAsyncKernelClient (AsyncKernelClient ):
16+ class DocumentAwareKernelClient (AsyncKernelClient ):
1317 """
14- A ZMQ-based kernel client class that managers all listeners to a kernel.
18+ A kernel client
1519 """
16- # Having this message cache is not ideal.
20+ # Having this message cache is not ideal.
1721 # Unfortunately, we don't include the parent channel
1822 # in the messages that generate IOPub status messages, thus,
1923 # we can't differential between the control channel vs.
20- # shell channel status. This message cache gives us
24+ # shell channel status. This message cache gives us
2125 # the ability to map status message back to their source.
2226 message_source_cache = Instance (
2327 default_value = LRUCache (maxsize = 1000 ), klass = LRUCache
2428 )
25-
26- # A set of callables that are called when a
27- # ZMQ message comes back from the kernel .
29+
30+ # A set of callables that are called when a kernel
31+ # message is received .
2832 _listeners = Set (allow_none = True )
2933
34+ # A set of YRooms that will intercept output and kernel
35+ # status messages.
36+ _yrooms : t .Set [YRoom ] = Set (trait = Instance (YRoom ), default_value = set ())
37+
38+
3039 async def start_listening (self ):
3140 """Start listening to messages coming from the kernel.
3241
@@ -39,17 +48,18 @@ async def _listening():
3948 tg .start_soon (
4049 self ._listen_for_messages , channel_name
4150 )
42-
51+
4352 # Background this task.
4453 self ._listening_task = asyncio .create_task (_listening ())
4554
46-
4755 async def stop_listening (self ):
56+ """Stop listening to the kernel.
57+ """
4858 # If the listening task isn't defined yet
4959 # do nothing.
5060 if not self ._listening_task :
5161 return
52-
62+
5363 # Attempt to cancel the task.
5464 try :
5565 self ._listening_task .cancel ()
@@ -60,10 +70,10 @@ async def stop_listening(self):
6070 # Log any exceptions that were raised.
6171 except Exception as err :
6272 self .log .error (err )
63-
73+
6474 _listening_task : t .Optional [t .Awaitable ] = Any (allow_none = True )
6575
66- def send_message (self , channel_name , msg ):
76+ def handle_incoming_message (self , channel_name : str , msg : list [ bytes ] ):
6777 """Use the given session to send the message."""
6878 # Cache the message ID and its socket name so that
6979 # any response message can be mapped back to the
@@ -73,32 +83,17 @@ def send_message(self, channel_name, msg):
7383 self .message_source_cache [msg_id ] = channel_name
7484 channel = getattr (self , f"{ channel_name } _channel" )
7585 channel .session .send_raw (channel .socket , msg )
76-
77- async def recv_message (self , channel_name , msg ):
78- """This is the main method that consumes every
79- message coming back from the kernel. It parses the header
80- (not the content, which might be large) and updates
81- the last_activity, execution_state, and lifecycle_state
82- when appropriate. Then, it routes the message
83- to all listeners.
86+
87+ def send_kernel_info (self ):
88+ """Sends a kernel info message on the shell channel. Useful
89+ for determining if the kernel is busy or idle.
8490 """
85- # Broadcast messages
86- async with anyio .create_task_group () as tg :
87- # Broadcast the message to all listeners.
88- for listener in self ._listeners :
89- async def _wrap_listener (listener_to_wrap , channel_name , msg ):
90- """
91- Wrap the listener to ensure its async and
92- logs (instead of raises) exceptions.
93- """
94- try :
95- listener_to_wrap (channel_name , msg )
96- except Exception as err :
97- self .log .error (err )
98-
99- tg .start_soon (_wrap_listener , listener , channel_name , msg )
91+ msg = self .session .msg ("kernel_info_request" )
92+ # Send message, skipping the delimiter and signature
93+ msg = self .session .serialize (msg )[2 :]
94+ self .handle_incoming_message ("shell" , msg )
10095
101- def add_listener (self , callback : t .Callable [[dict ], None ]):
96+ def add_listener (self , callback : t .Callable [[str , list [ bytes ] ], None ]):
10297 """Add a listener to the ZMQ Interface.
10398
10499 A listener is a callable function/method that takes
@@ -108,13 +103,13 @@ def add_listener(self, callback: t.Callable[[dict], None]):
108103 """
109104 self ._listeners .add (callback )
110105
111- def remove_listener (self , callback : t .Callable [[dict ], None ]):
112- """Remove a listener to the ZMQ interface . If the listener
106+ def remove_listener (self , callback : t .Callable [[str , list [ bytes ] ], None ]):
107+ """Remove a listener. If the listener
113108 is not found, this method does nothing.
114109 """
115110 self ._listeners .discard (callback )
116111
117- async def _listen_for_messages (self , channel_name ):
112+ async def _listen_for_messages (self , channel_name : str ):
118113 """The basic polling loop for listened to kernel messages
119114 on a ZMQ socket.
120115 """
@@ -125,16 +120,118 @@ async def _listen_for_messages(self, channel_name):
125120 # Wait for a message
126121 await channel .socket .poll (timeout = float ("inf" ))
127122 raw_msg = await channel .socket .recv_multipart ()
123+ # Drop identities and delimit from the message parts.
124+ _ , fed_msg_list = self .session .feed_identities (raw_msg )
125+ msg = fed_msg_list
128126 try :
129- await self .recv_message (channel_name , raw_msg )
127+ await self .handle_outgoing_message (channel_name , msg )
130128 except Exception as err :
131129 self .log .error (err )
130+
131+ async def send_message_to_listeners (self , channel_name : str , msg : list [bytes ]):
132+ """
133+ Sends message to all registered listeners.
134+ """
135+ async with anyio .create_task_group () as tg :
136+ # Broadcast the message to all listeners.
137+ for listener in self ._listeners :
138+ async def _wrap_listener (listener_to_wrap , channel_name , msg ):
139+ """
140+ Wrap the listener to ensure its async and
141+ logs (instead of raises) exceptions.
142+ """
143+ try :
144+ listener_to_wrap (channel_name , msg )
145+ except Exception as err :
146+ self .log .error (err )
147+
148+ tg .start_soon (_wrap_listener , listener , channel_name , msg )
149+
150+ async def handle_outgoing_message (self , channel_name : str , msg : list [bytes ]):
151+ """This is the main method that consumes every
152+ message coming back from the kernel. It parses the header
153+ (not the content, which might be large) and updates
154+ the last_activity, execution_state, and lifecycsle_state
155+ when appropriate. Then, it routes the message
156+ to all listeners.
157+ """
158+
159+ # Intercept messages that are IOPub focused.
160+ if channel_name == "iopub" :
161+ message_returned = await self .handle_iopub_message (msg )
162+ # TODO: If the message is not returned by the iopub handler, then
163+ # return here and do not forward to listeners.
164+ if not message_returned :
165+ self .log .warn (f"If message is handled donot forward after adding output manager" )
166+ return
167+
168+ # Update the last activity.
169+ #self.last_activity = self.session.msg_time
170+
171+ await self .send_message_to_listeners (channel_name , msg )
172+
173+ async def handle_iopub_message (self , msg : list [bytes ]) -> t .Optional [list [bytes ]]:
174+ """
175+ Handle messages
176+
177+ Parameters
178+ ----------
179+ dmsg: dict
180+ Deserialized message (except concept)
181+
182+ Returns
183+ -------
184+ Returns the message if it should be forwarded to listeners. Otherwise,
185+ returns `None` and keeps (i.e. intercepts) the message from going
186+ to listenres.
187+ """
188+ # NOTE: Here's where we will inject the kernel state
189+ # into the awareness of a document.
190+
191+ try :
192+ dmsg = self .session .deserialize (msg , content = False )
193+ except Exception as e :
194+ self .log .error (f"Error deserializing message: { e } " )
195+ raise ValueError
132196
133- def send_kernel_info (self ):
134- """Sends a kernel info message on the shell channel. Useful
135- for determining if the kernel is busy or idle.
197+ if dmsg ["msg_type" ] == "status" :
198+ # Forward to all yrooms.
199+ for yroom in self ._yrooms :
200+ # NOTE: We need to create a real message here.
201+ awareness_update_message = b""
202+ self .log .debug (f"Update Awareness here: { dmsg } . YRoom: { yroom } " )
203+ #self.log.debug(f"Getting YDoc: {await yroom.get_ydoc()}")
204+ #yroom.add_message(awareness_update_message)
205+
206+ # TODO: returning message temporarily to not break UI
207+ return msg
208+
209+
210+ # NOTE: Inject display data into ydoc.
211+ if dmsg ["msg_type" ] == "display_data" :
212+ # Forward to all yrooms.
213+ for yroom in self ._yrooms :
214+ update_document_message = b""
215+ self .log .debug (f"Update Document here: { dmsg } . Yroom: { yroom } " )
216+ #self.log.debug(f"Getting YDoc: {await yroom.get_ydoc()}")
217+ #yroom.add_message(update_document_message)
218+
219+ # TODO: returning message temporarily to not break UI
220+ return msg
221+
222+ # If the message isn't handled above, return it and it will
223+ # be forwarded to all listeners
224+ return msg
225+
226+ async def add_yroom (self , yroom : YRoom ):
136227 """
137- msg = self .session .msg ("kernel_info_request" )
138- # Send message, skipping the delimiter and signature
139- msg = self .session .serialize (msg )[2 :]
140- self .send_message ("shell" , msg )
228+ Register a YRoom with this kernel client. YRooms will
229+ intercept display and kernel status messages.
230+ """
231+ self ._yrooms .add (yroom )
232+
233+ async def remove_yroom (self , yroom : YRoom ):
234+ """
235+ De-register a YRoom from handling kernel client messages.
236+ """
237+ self ._yrooms .discard (yroom )
0 commit comments