 from traitlets import Set, Instance, Any, Type, default
 from jupyter_client.asynchronous.client import AsyncKernelClient
 
-from .utils import LRUCache
+from .message_cache import KernelMessageCache
 from jupyter_rtc_core.rooms.yroom import YRoom
 from jupyter_rtc_core.outputs import OutputProcessor
 from jupyter_server.utils import ensure_async
 
+from .kernel_client_abc import DocumentAwareKernelClientABC
 
-class DocumentAwareKernelClient(AsyncKernelClient):
+
+class DocumentAwareKernelClient(AsyncKernelClient):
     """
-    A kernel client
+    A kernel client that routes messages to registered ydocs.
     """
     # Having this message cache is not ideal.
     # Unfortunately, we don't include the parent channel
     # in the messages that generate IOPub status messages, thus,
     # we can't differentiate between the control channel vs.
     # shell channel status. This message cache gives us
     # the ability to map status messages back to their source.
-    message_source_cache = Instance(
-        default_value=LRUCache(maxsize=1000), klass=LRUCache
+    message_cache = Instance(
+        default_value=KernelMessageCache(maxsize=10000), klass=KernelMessageCache
     )
 
     # A set of callables that are called when a kernel
@@ -37,6 +39,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
     # status messages.
     _yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set())
 
+
     output_processor = Instance(
         OutputProcessor,
         allow_none=True
@@ -50,7 +53,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
     @default("output_processor")
     def _default_output_processor(self) -> OutputProcessor:
         self.log.info("Creating output processor")
-        return OutputProcessor(parent=self, config=self.config)
+        return self.output_process_class(parent=self, config=self.config)
 
     async def start_listening(self):
         """Start listening to messages coming from the kernel.
@@ -94,10 +97,23 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
         # Cache the message ID and its socket name so that
         # any response message can be mapped back to the
         # source channel.
-        self.output_processor.process_incoming_message(channel=channel_name, msg=msg)
-        header = json.loads(msg[0])  # TODO: use session.unpack
-        msg_id = header["msg_id"]
-        self.message_source_cache[msg_id] = channel_name
+        header = self.session.unpack(msg[0])
+        msg_id = header["msg_id"]
+        metadata = self.session.unpack(msg[2])
+        cell_id = metadata.get("cellId")
+
+        # Clear output processor if this cell already has
+        # an existing request.
+        if cell_id:
+            existing = self.message_cache.get(cell_id=cell_id)
+            if existing and existing['msg_id'] != msg_id:
+                self.output_processor.clear(cell_id)
+
+        self.message_cache.add({
+            "msg_id": msg_id,
+            "channel": channel_name,
+            "cell_id": cell_id
+        })
         channel = getattr(self, f"{channel_name}_channel")
         channel.session.send_raw(channel.socket, msg)
 
@@ -152,7 +168,7 @@ async def send_message_to_listeners(self, channel_name: str, msg: list[bytes]):
         async with anyio.create_task_group() as tg:
             # Broadcast the message to all listeners.
             for listener in self._listeners:
-                async def _wrap_listener(listener_to_wrap, channel_name, msg):
+                async def _wrap_listener(listener_to_wrap, channel_name, msg):
                     """
                     Wrap the listener to ensure it's async and
                     logs (instead of raises) exceptions.
@@ -172,63 +188,98 @@ async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
         when appropriate. Then, it routes the message
         to all listeners.
         """
-        # Intercept messages that are IOPub focused.
-        if channel_name == "iopub":
-            message_returned = await self.handle_iopub_message(msg)
-            # If the message is not returned by the iopub handler, then
-            # return here and do not forward to listeners.
-            if not message_returned:
-                self.log.warn(f"If message is handled do not forward after adding output manager")
+        if channel_name in ('iopub', 'shell'):
+            msg = await self.handle_document_related_message(msg)
+            # If msg has been cleared by the handler, escape this method.
+            if msg is None:
                 return
-
-        # Update the last activity.
-        # self.last_activity = self.session.msg_time
+
         await self.send_message_to_listeners(channel_name, msg)
 
-    async def handle_iopub_message(self, msg: list[bytes]) -> t.Optional[list[bytes]]:
+    async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optional[t.List[bytes]]:
         """
-        Handle messages
+        Processes document-related messages received from a Jupyter kernel.
 
-        Parameters
-        ----------
-        dmsg: dict
-            Deserialized message (except concept)
-
-        Returns
-        -------
-        Returns the message if it should be forwarded to listeners. Otherwise,
-        returns `None` and prevents (i.e. intercepts) the message from going
-        to listeners.
-        """
+        Messages are deserialized and handled based on their type. Supported message types
+        include updating language info, kernel status, execution state, execution count,
+        and various output types. Some messages may be processed by an output processor
+        before deciding whether to forward them.
 
+        Returns the original message if it is not processed further, otherwise None to indicate
+        that the message should not be forwarded.
+        """
+        # Begin to deserialize the message safely within a try-except block
         try:
             dmsg = self.session.deserialize(msg, content=False)
         except Exception as e:
             self.log.error(f"Error deserializing message: {e}")
             raise
 
-        if self.output_processor is not None and dmsg["msg_type"] in ("stream", "display_data", "execute_result", "error"):
-            dmsg = self.output_processor.process_outgoing_message(dmsg)
-
-        # If process_outgoing_message returns None, return None so the message isn't
-        # sent to clients, otherwise return the original serialized message.
-        if dmsg is None:
-            return None
-        else:
-            return msg
-
-    def send_kernel_awareness(self, kernel_status: dict):
-        """
-        Send kernel status awareness messages to all yrooms
-        """
-        for yroom in self._yrooms:
-            awareness = yroom.get_awareness()
-            if awareness is None:
-                self.log.error(f"awareness cannot be None. room_id: {yroom.room_id}")
-                continue
-            self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}. kernel status: {kernel_status}")
-            awareness.set_local_state_field("kernel", kernel_status)
-            self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}")
+        parent_msg_id = dmsg["parent_header"]["msg_id"]
+        parent_msg_data = self.message_cache.get(parent_msg_id)
+
+        # Handle different message types using pattern matching
+        match dmsg["msg_type"]:
+            case "kernel_info_reply":
+                # Unpack the content to extract language info
+                content = self.session.unpack(dmsg["content"])
+                language_info = content["language_info"]
+                # Update the language info metadata for each collaborative room
+                for yroom in self._yrooms:
+                    notebook = await yroom.get_jupyter_ydoc()
+                    # The metadata ydoc is not exposed as a
+                    # public property.
+                    metadata = notebook._ymeta
+                    metadata["metadata"]["language_info"] = language_info
+
+            case "status":
+                # Unpack cell-specific information and determine execution state
+                cell_id = parent_msg_data.get('cell_id')
+                content = self.session.unpack(dmsg["content"])
+                execution_state = content.get("execution_state")
+                # Update status across all collaborative rooms
+                for yroom in self._yrooms:
+                    # If this status came from the shell channel, update
+                    # the notebook status.
+                    if parent_msg_data["channel"] == "shell":
+                        awareness = yroom.get_awareness()
+                        if awareness is not None:
+                            # Update the kernel execution state at the top document level
+                            awareness.set_local_state_field("kernel", {"execution_state": execution_state})
+                        # Specifically update the running cell's execution state if cell_id is provided
+                        if cell_id:
+                            notebook = await yroom.get_jupyter_ydoc()
+                            cells = notebook.ycells
+                            _, target_cell = notebook.find_cell(cell_id, cells)
+                            if target_cell:
+                                # Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation
+                                # https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678
+                                state = 'running' if execution_state == 'busy' else execution_state
+                                target_cell["execution_state"] = state
+
+            case "execute_input":
+                # Extract execution count and update each collaborative room's notebook
+                cell_id = parent_msg_data.get('cell_id')
+                content = self.session.unpack(dmsg["content"])
+                execution_count = content["execution_count"]
+                for yroom in self._yrooms:
+                    notebook = await yroom.get_jupyter_ydoc()
+                    cells = notebook.ycells
+                    _, target_cell = notebook.find_cell(cell_id, cells)
+                    if target_cell:
+                        target_cell["execution_count"] = execution_count
+
+            case "stream" | "display_data" | "execute_result" | "error":
+                # Process specific output messages through an optional processor
+                if self.output_processor:
+                    cell_id = parent_msg_data.get('cell_id')
+                    content = self.session.unpack(dmsg["content"])
+                    dmsg = self.output_processor.process_outgoing(dmsg['msg_type'], cell_id, content)
+                    # Suppress forwarding of processed messages by returning None
+                    return None
+
+        # Default: return the message unchanged so it is forwarded to listeners
+        return msg
 
     async def add_yroom(self, yroom: YRoom):
         """
@@ -242,3 +293,6 @@ async def remove_yroom(self, yroom: YRoom):
         De-register a YRoom from handling kernel client messages.
         """
         self._yrooms.discard(yroom)
+
+
+DocumentAwareKernelClientABC.register(DocumentAwareKernelClient)
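
Note: KernelMessageCache itself is defined in .message_cache and does not appear in this diff. The sketch below is only an assumption about the interface these changes rely on, inferred from the call sites in handle_incoming_message and handle_document_related_message (add({...}), get(msg_id), get(cell_id=...)); the real implementation may differ.

    # Hypothetical sketch only: an LRU-style store of
    # {"msg_id", "channel", "cell_id"} records, addressable by msg_id or cell_id.
    from collections import OrderedDict
    from typing import Optional


    class KernelMessageCache:
        def __init__(self, maxsize: int = 10000):
            self.maxsize = maxsize
            self._by_msg_id = OrderedDict()  # msg_id -> record
            self._by_cell_id = {}            # cell_id -> msg_id

        def add(self, entry: dict) -> None:
            # Insert (or refresh) a record, evicting the oldest beyond maxsize.
            msg_id = entry["msg_id"]
            self._by_msg_id[msg_id] = entry
            self._by_msg_id.move_to_end(msg_id)
            if entry.get("cell_id"):
                self._by_cell_id[entry["cell_id"]] = msg_id
            while len(self._by_msg_id) > self.maxsize:
                old_id, old = self._by_msg_id.popitem(last=False)
                if old.get("cell_id") and self._by_cell_id.get(old["cell_id"]) == old_id:
                    del self._by_cell_id[old["cell_id"]]

        def get(self, msg_id: Optional[str] = None, *, cell_id: Optional[str] = None) -> Optional[dict]:
            # Look up by msg_id (positional) or by cell_id (keyword), matching the call sites above.
            if msg_id is not None:
                return self._by_msg_id.get(msg_id)
            if cell_id is not None:
                key = self._by_cell_id.get(cell_id)
                return self._by_msg_id.get(key) if key is not None else None
            return None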