Skip to content

Commit f77414b

Browse files
committed
Route kernel messages about language_info, execution_state, and execution_count to the ydoc
1 parent d40aaf7 commit f77414b

File tree

10 files changed

+776
-214
lines changed

10 files changed

+776
-214
lines changed

jupyter_rtc_core/kernels/kernel_client.py

Lines changed: 110 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,26 @@
99
from traitlets import Set, Instance, Any, Type, default
1010
from jupyter_client.asynchronous.client import AsyncKernelClient
1111

12-
from .utils import LRUCache
12+
from .message_cache import KernelMessageCache
1313
from jupyter_rtc_core.rooms.yroom import YRoom
1414
from jupyter_rtc_core.outputs import OutputProcessor
1515
from jupyter_server.utils import ensure_async
1616

17+
from .kernel_client_abc import DocumentAwareKernelClientABC
1718

18-
class DocumentAwareKernelClient(AsyncKernelClient):
19+
20+
class DocumentAwareKernelClient(AsyncKernelClient):
1921
"""
20-
A kernel client
22+
A kernel client that routes messages to registered ydocs.
2123
"""
2224
# Having this message cache is not ideal.
2325
# Unfortunately, we don't include the parent channel
2426
# in the messages that generate IOPub status messages, thus,
2527
# we can't differential between the control channel vs.
2628
# shell channel status. This message cache gives us
2729
# the ability to map status message back to their source.
28-
message_source_cache = Instance(
29-
default_value=LRUCache(maxsize=1000), klass=LRUCache
30+
message_cache = Instance(
31+
default_value=KernelMessageCache(maxsize=10000), klass=KernelMessageCache
3032
)
3133

3234
# A set of callables that are called when a kernel
@@ -37,6 +39,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
3739
# status messages.
3840
_yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set())
3941

42+
4043
output_processor = Instance(
4144
OutputProcessor,
4245
allow_none=True
@@ -50,7 +53,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
5053
@default("output_processor")
5154
def _default_output_processor(self) -> OutputProcessor:
5255
self.log.info("Creating output processor")
53-
return OutputProcessor(parent=self, config=self.config)
56+
return self.output_process_class(parent=self, config=self.config)
5457

5558
async def start_listening(self):
5659
"""Start listening to messages coming from the kernel.
@@ -94,10 +97,23 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
9497
# Cache the message ID and its socket name so that
9598
# any response message can be mapped back to the
9699
# source channel.
97-
self.output_processor.process_incoming_message(channel=channel_name, msg=msg)
98-
header = json.loads(msg[0]) # TODO: use session.unpack
99-
msg_id = header["msg_id"]
100-
self.message_source_cache[msg_id] = channel_name
100+
header = self.session.unpack(msg[0])
101+
msg_id = header["msg_id"]
102+
metadata = self.session.unpack(msg[2])
103+
cell_id = metadata.get("cellId")
104+
105+
# Clear output processor if this cell already has
106+
# an existing request.
107+
if cell_id:
108+
existing = self.message_cache.get(cell_id=cell_id)
109+
if existing and existing['msg_id'] != msg_id:
110+
self.output_processor.clear(cell_id)
111+
112+
self.message_cache.add({
113+
"msg_id": msg_id,
114+
"channel": channel_name,
115+
"cell_id": cell_id
116+
})
101117
channel = getattr(self, f"{channel_name}_channel")
102118
channel.session.send_raw(channel.socket, msg)
103119

@@ -152,7 +168,7 @@ async def send_message_to_listeners(self, channel_name: str, msg: list[bytes]):
152168
async with anyio.create_task_group() as tg:
153169
# Broadcast the message to all listeners.
154170
for listener in self._listeners:
155-
async def _wrap_listener(listener_to_wrap, channel_name, msg):
171+
async def _wrap_listener(listener_to_wrap, channel_name, msg):
156172
"""
157173
Wrap the listener to ensure its async and
158174
logs (instead of raises) exceptions.
@@ -172,63 +188,98 @@ async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
172188
when appropriate. Then, it routes the message
173189
to all listeners.
174190
"""
175-
# Intercept messages that are IOPub focused.
176-
if channel_name == "iopub":
177-
message_returned = await self.handle_iopub_message(msg)
178-
# If the message is not returned by the iopub handler, then
179-
# return here and do not forward to listeners.
180-
if not message_returned:
181-
self.log.warn(f"If message is handled do not forward after adding output manager")
191+
if channel_name in ('iopub', 'shell'):
192+
msg = await self.handle_document_related_message(msg)
193+
# If msg has been cleared by the handler, escape this method.
194+
if msg is None:
182195
return
183-
184-
# Update the last activity.
185-
# self.last_activity = self.session.msg_time
196+
186197
await self.send_message_to_listeners(channel_name, msg)
187198

188-
async def handle_iopub_message(self, msg: list[bytes]) -> t.Optional[list[bytes]]:
199+
async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optional[t.List[bytes]]:
189200
"""
190-
Handle messages
201+
Processes document-related messages received from a Jupyter kernel.
191202
192-
Parameters
193-
----------
194-
dmsg: dict
195-
Deserialized message (except concept)
196-
197-
Returns
198-
-------
199-
Returns the message if it should be forwarded to listeners. Otherwise,
200-
returns `None` and prevents (i.e. intercepts) the message from going
201-
to listeners.
202-
"""
203+
Messages are deserialized and handled based on their type. Supported message types
204+
include updating language info, kernel status, execution state, execution count,
205+
and various output types. Some messages may be processed by an output processor
206+
before deciding whether to forward them.
203207
208+
Returns the original message if it is not processed further, otherwise None to indicate
209+
that the message should not be forwarded.
210+
"""
211+
# Begin to deserialize the message safely within a try-except block
204212
try:
205213
dmsg = self.session.deserialize(msg, content=False)
206214
except Exception as e:
207215
self.log.error(f"Error deserializing message: {e}")
208216
raise
209217

210-
if self.output_processor is not None and dmsg["msg_type"] in ("stream", "display_data", "execute_result", "error"):
211-
dmsg = self.output_processor.process_outgoing_message(dmsg)
212-
213-
# If process_outgoing_message returns None, return None so the message isn't
214-
# sent to clients, otherwise return the original serialized message.
215-
if dmsg is None:
216-
return None
217-
else:
218-
return msg
219-
220-
def send_kernel_awareness(self, kernel_status: dict):
221-
"""
222-
Send kernel status awareness messages to all yrooms
223-
"""
224-
for yroom in self._yrooms:
225-
awareness = yroom.get_awareness()
226-
if awareness is None:
227-
self.log.error(f"awareness cannot be None. room_id: {yroom.room_id}")
228-
continue
229-
self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}. kernel status: {kernel_status}")
230-
awareness.set_local_state_field("kernel", kernel_status)
231-
self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}")
218+
parent_msg_id = dmsg["parent_header"]["msg_id"]
219+
parent_msg_data = self.message_cache.get(parent_msg_id)
220+
221+
# Handle different message types using pattern matching
222+
match dmsg["msg_type"]:
223+
case "kernel_info_reply":
224+
# Unpack the content to extract language info
225+
content = self.session.unpack(dmsg["content"])
226+
language_info = content["language_info"]
227+
# Update the language info metadata for each collaborative room
228+
for yroom in self._yrooms:
229+
notebook = await yroom.get_jupyter_ydoc()
230+
# The metadata ydoc is not exposed as a
231+
# public property.
232+
metadata = notebook._ymeta
233+
metadata["metadata"]["language_info"] = language_info
234+
235+
case "status":
236+
# Unpack cell-specific information and determine execution state
237+
cell_id = parent_msg_data.get('cell_id')
238+
content = self.session.unpack(dmsg["content"])
239+
execution_state = content.get("execution_state")
240+
# Update status across all collaborative rooms
241+
for yroom in self._yrooms:
242+
# If this status came from the shell channel, update
243+
# the notebook status.
244+
if parent_msg_data["channel"] == "shell":
245+
awareness = yroom.get_awareness()
246+
if awareness is not None:
247+
# Update the kernel execution state at the top document level
248+
awareness.set_local_state_field("kernel", {"execution_state": execution_state})
249+
# Specifically update the running cell's execution state if cell_id is provided
250+
if cell_id:
251+
notebook = await yroom.get_jupyter_ydoc()
252+
cells = notebook.ycells
253+
_, target_cell = notebook.find_cell(cell_id, cells)
254+
if target_cell:
255+
# Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation
256+
# https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678
257+
state = 'running' if execution_state == 'busy' else execution_state
258+
target_cell["execution_state"] = state
259+
260+
case "execute_input":
261+
# Extract execution count and update each collaborative room's notebook
262+
cell_id = parent_msg_data.get('cell_id')
263+
content = self.session.unpack(dmsg["content"])
264+
execution_count = content["execution_count"]
265+
for yroom in self._yrooms:
266+
notebook = await yroom.get_jupyter_ydoc()
267+
cells = notebook.ycells
268+
_, target_cell = notebook.find_cell(cell_id, cells)
269+
if target_cell:
270+
target_cell["execution_count"] = execution_count
271+
272+
case "stream" | "display_data" | "execute_result" | "error":
273+
# Process specific output messages through an optional processor
274+
if self.output_processor:
275+
cell_id = parent_msg_data.get('cell_id')
276+
content = self.session.unpack(dmsg["content"])
277+
dmsg = self.output_processor.process_outgoing(dmsg['msg_type'], cell_id, content)
278+
# Suppress forwarding of processed messages by returning None
279+
return None
280+
281+
# Default return if message is processed and does not need forwarding
282+
return msg
232283

233284
async def add_yroom(self, yroom: YRoom):
234285
"""
@@ -242,3 +293,6 @@ async def remove_yroom(self, yroom: YRoom):
242293
De-register a YRoom from handling kernel client messages.
243294
"""
244295
self._yrooms.discard(yroom)
296+
297+
298+
DocumentAwareKernelClientABC.register(DocumentAwareKernelClient)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import typing as t
2+
from abc import ABC, abstractmethod
3+
4+
from jupyter_rtc_core.rooms.yroom import YRoom
5+
6+
7+
class KernelClientABC(ABC):
    """Abstract interface for a kernel client that proxies raw kernel messages.

    Implementations (see ``DocumentAwareKernelClient``) relay serialized
    multipart ZMQ frames between the kernel's channels and a set of
    registered listener callbacks.
    """

    @abstractmethod
    async def start_listening(self):
        """Start consuming messages coming from the kernel's channels."""
        ...

    @abstractmethod
    async def stop_listening(self):
        """Stop consuming messages from the kernel's channels."""
        ...

    @abstractmethod
    def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
        """Handle a serialized message headed to the kernel on ``channel_name``."""
        ...

    @abstractmethod
    async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
        """Handle a serialized message received from the kernel on ``channel_name``."""
        ...

    @abstractmethod
    def add_listener(self, callback: t.Callable[[str, list[bytes]], None]):
        """Register a callback invoked with ``(channel_name, msg)`` for kernel messages."""
        ...

    @abstractmethod
    def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
        """Remove a previously registered listener callback."""
        ...
32+
33+
34+
class DocumentAwareKernelClientABC(KernelClientABC):
    """Kernel client interface that additionally routes messages to YRooms.

    Extends :class:`KernelClientABC` with registration hooks for the
    collaborative rooms whose ydocs should receive kernel-driven updates
    (language info, execution state, execution counts, outputs).
    """

    @abstractmethod
    async def add_yroom(self, yroom: YRoom):
        """Register a YRoom to receive kernel client message updates."""
        ...

    @abstractmethod
    async def remove_yroom(self, yroom: YRoom):
        """De-register a YRoom from handling kernel client messages."""
        ...

jupyter_rtc_core/kernels/kernel_manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ async def connect(self):
120120
self.set_state(LifecycleStates.UNKNOWN, ExecutionStates.UNKNOWN)
121121
raise Exception("The kernel took too long to connect to the ZMQ sockets.")
122122
# Wait a second until the next time we try again.
123-
await asyncio.sleep(1)
123+
await asyncio.sleep(.5)
124124
# Wait for the kernel to reach an idle state.
125125
while self.execution_state != ExecutionStates.IDLE.value:
126126
self.main_client.send_kernel_info()
127-
await asyncio.sleep(0.5)
127+
await asyncio.sleep(0.1)
128128

129129
async def disconnect(self):
130130
await self.main_client.stop_listening()
@@ -139,7 +139,11 @@ async def broadcast_state(self):
139139
session = self.main_client.session
140140
parent_header = session.msg_header("status")
141141
parent_msg_id = parent_header["msg_id"]
142-
self.main_client.message_source_cache[parent_msg_id] = "shell"
142+
self.main_client.message_cache.add({
143+
"msg_id": parent_msg_id,
144+
"channel": "shell",
145+
"cellId": None
146+
})
143147
msg = session.msg("status", content={"execution_state": self.execution_state}, parent=parent_header)
144148
smsg = session.serialize(msg)[1:]
145149
await self.main_client.handle_outgoing_message("iopub", smsg)
@@ -162,14 +166,9 @@ def execution_state_listener(self, channel_name: str, msg: list[bytes]):
162166
else:
163167
parent = deserialized_msg.get("parent_header", {})
164168
msg_id = parent.get("msg_id", "")
165-
parent_channel = self.main_client.message_source_cache.get(msg_id, None)
169+
message_data = self.main_client.message_cache.get(msg_id)
170+
if message_data is None:
171+
return
172+
parent_channel = message_data.get("channel")
166173
if parent_channel and parent_channel == "shell":
167-
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))
168-
169-
kernel_status = {
170-
"execution_state": self.execution_state,
171-
"lifecycle_state": self.lifecycle_state
172-
}
173-
self.log.debug(f"Sending kernel status awareness {kernel_status}")
174-
self.main_client.send_kernel_awareness(kernel_status)
175-
self.log.debug(f"Sent kernel status awareness {kernel_status}")
174+
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))

0 commit comments

Comments
 (0)