Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 113 additions & 56 deletions jupyter_server_documents/kernels/kernel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,32 @@
from traitlets import Set, Instance, Any, Type, default
from jupyter_client.asynchronous.client import AsyncKernelClient

from .utils import LRUCache
from .message_cache import KernelMessageCache
from jupyter_server_documents.rooms.yroom import YRoom
from jupyter_server_documents.outputs import OutputProcessor
from jupyter_server.utils import ensure_async

from .kernel_client_abc import AbstractDocumentAwareKernelClient

class DocumentAwareKernelClient(AsyncKernelClient):

class DocumentAwareKernelClient(AsyncKernelClient):
"""
A kernel client
A kernel client that routes messages to registered ydocs.
"""
# Having this message cache is not ideal.
# Unfortunately, we don't include the parent channel
# in the messages that generate IOPub status messages, thus,
# we can't differentiate between the control channel vs.
# shell channel status. This message cache gives us
# the ability to map status messages back to their source.
message_source_cache = Instance(
default_value=LRUCache(maxsize=1000), klass=LRUCache
message_cache = Instance(
klass=KernelMessageCache
)

@default('message_cache')
def _default_message_cache(self):
return KernelMessageCache(parent=self)

# A set of callables that are called when a kernel
# message is received.
_listeners = Set(allow_none=True)
Expand All @@ -37,6 +43,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
# status messages.
_yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set())


output_processor = Instance(
OutputProcessor,
allow_none=True
Expand All @@ -50,7 +57,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
@default("output_processor")
def _default_output_processor(self) -> OutputProcessor:
self.log.info("Creating output processor")
return OutputProcessor(parent=self, config=self.config)
return self.output_process_class(parent=self, config=self.config)

async def start_listening(self):
"""Start listening to messages coming from the kernel.
Expand Down Expand Up @@ -94,10 +101,23 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
# Cache the message ID and its socket name so that
# any response message can be mapped back to the
# source channel.
self.output_processor.process_incoming_message(channel=channel_name, msg=msg)
header = json.loads(msg[0]) # TODO: use session.unpack
msg_id = header["msg_id"]
self.message_source_cache[msg_id] = channel_name
header = self.session.unpack(msg[0])
msg_id = header["msg_id"]
metadata = self.session.unpack(msg[2])
cell_id = metadata.get("cellId")

# Clear output processor if this cell already has
# an existing request.
if cell_id:
existing = self.message_cache.get(cell_id=cell_id)
if existing and existing['msg_id'] != msg_id:
self.output_processor.clear(cell_id)

self.message_cache.add({
"msg_id": msg_id,
"channel": channel_name,
"cell_id": cell_id
})
channel = getattr(self, f"{channel_name}_channel")
channel.session.send_raw(channel.socket, msg)

Expand Down Expand Up @@ -152,7 +172,7 @@ async def send_message_to_listeners(self, channel_name: str, msg: list[bytes]):
async with anyio.create_task_group() as tg:
# Broadcast the message to all listeners.
for listener in self._listeners:
async def _wrap_listener(listener_to_wrap, channel_name, msg):
async def _wrap_listener(listener_to_wrap, channel_name, msg):
"""
Wrap the listener to ensure it's async and
logs (instead of raises) exceptions.
Expand All @@ -172,63 +192,97 @@ async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
when appropriate. Then, it routes the message
to all listeners.
"""
# Intercept messages that are IOPub focused.
if channel_name == "iopub":
message_returned = await self.handle_iopub_message(msg)
# If the message is not returned by the iopub handler, then
# return here and do not forward to listeners.
if not message_returned:
self.log.warn(f"If message is handled do not forward after adding output manager")
if channel_name in ('iopub', 'shell'):
msg = await self.handle_document_related_message(msg)
# If msg has been cleared by the handler, escape this method.
if msg is None:
return

# Update the last activity.
# self.last_activity = self.session.msg_time

await self.send_message_to_listeners(channel_name, msg)

async def handle_iopub_message(self, msg: list[bytes]) -> t.Optional[list[bytes]]:
async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optional[t.List[bytes]]:
"""
Handle messages
Processes document-related messages received from a Jupyter kernel.

Parameters
----------
dmsg: dict
Deserialized message (except content)

Returns
-------
Returns the message if it should be forwarded to listeners. Otherwise,
returns `None` and prevents (i.e. intercepts) the message from going
to listeners.
"""
Messages are deserialized and handled based on their type. Supported message types
include updating language info, kernel status, execution state, execution count,
and various output types. Some messages may be processed by an output processor
before deciding whether to forward them.

Returns the original message if it is not processed further, otherwise None to indicate
that the message should not be forwarded.
"""
# Begin to deserialize the message safely within a try-except block
try:
dmsg = self.session.deserialize(msg, content=False)
except Exception as e:
self.log.error(f"Error deserializing message: {e}")
raise

if self.output_processor is not None and dmsg["msg_type"] in ("stream", "display_data", "execute_result", "error"):
dmsg = self.output_processor.process_outgoing_message(dmsg)

# If process_outgoing_message returns None, return None so the message isn't
# sent to clients, otherwise return the original serialized message.
if dmsg is None:
return None
else:
return msg

def send_kernel_awareness(self, kernel_status: dict):
"""
Send kernel status awareness messages to all yrooms
"""
for yroom in self._yrooms:
awareness = yroom.get_awareness()
if awareness is None:
self.log.error(f"awareness cannot be None. room_id: {yroom.room_id}")
continue
self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}. kernel status: {kernel_status}")
awareness.set_local_state_field("kernel", kernel_status)
self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}")
parent_msg_id = dmsg["parent_header"]["msg_id"]
parent_msg_data = self.message_cache.get(parent_msg_id)
cell_id = parent_msg_data.get('cell_id')

# Handle different message types using pattern matching
match dmsg["msg_type"]:
case "kernel_info_reply":
# Unpack the content to extract language info
content = self.session.unpack(dmsg["content"])
language_info = content["language_info"]
# Update the language info metadata for each collaborative room
for yroom in self._yrooms:
notebook = await yroom.get_jupyter_ydoc()
# The metadata ydoc is not exposed as a
# public property.
metadata = notebook.ymeta
metadata["metadata"]["language_info"] = language_info

case "status":
# Unpack cell-specific information and determine execution state
content = self.session.unpack(dmsg["content"])
execution_state = content.get("execution_state")
# Update status across all collaborative rooms
for yroom in self._yrooms:
# If this status came from the shell channel, update
# the notebook status.
if parent_msg_data["channel"] == "shell":
awareness = yroom.get_awareness()
if awareness is not None:
# Update the kernel execution state at the top document level
awareness.set_local_state_field("kernel", {"execution_state": execution_state})
# Specifically update the running cell's execution state if cell_id is provided
if cell_id:
notebook = await yroom.get_jupyter_ydoc()
_, target_cell = notebook.find_cell(cell_id)
Copy link
Collaborator

@3coins 3coins Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Zsailer
PR 80 also passed cells as a second argument here, is this missing? I can't locate the find_cell API in YNotebook or YBaseDoc to verify this.
https://github.com/jupyter-ai-contrib/jupyter-server-documents/pull/81/files#diff-7c1b5e5cd83f31f24af67c439d6a9422528ddb74a5b20598b25230a798d613d7R252-R253

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this API to the ydocs.py module here. We don't need to pass the list of cells anymore, since that class already has the list as a property.

if target_cell:
# Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation
# https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678
state = 'running' if execution_state == 'busy' else execution_state
target_cell["execution_state"] = state

case "execute_input":
if cell_id:
# Extract execution count and update each collaborative room's notebook
content = self.session.unpack(dmsg["content"])
execution_count = content["execution_count"]
for yroom in self._yrooms:
notebook = await yroom.get_jupyter_ydoc()
_, target_cell = notebook.find_cell(cell_id)
if target_cell:
target_cell["execution_count"] = execution_count

case "stream" | "display_data" | "execute_result" | "error":
if cell_id:
# Process specific output messages through an optional processor
if self.output_processor and cell_id:
cell_id = parent_msg_data.get('cell_id')
content = self.session.unpack(dmsg["content"])
dmsg = self.output_processor.process_output(dmsg['msg_type'], cell_id, content)
# Suppress forwarding of processed messages by returning None
return None

# Default return if message is processed and does not need forwarding
return msg

async def add_yroom(self, yroom: YRoom):
"""
Expand All @@ -242,3 +296,6 @@ async def remove_yroom(self, yroom: YRoom):
De-register a YRoom from handling kernel client messages.
"""
self._yrooms.discard(yroom)


AbstractDocumentAwareKernelClient.register(DocumentAwareKernelClient)
42 changes: 42 additions & 0 deletions jupyter_server_documents/kernels/kernel_client_abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import typing as t
from abc import ABC, abstractmethod

from jupyter_server_documents.rooms.yroom import YRoom


class AbstractKernelClient(ABC):
    """
    Abstract interface for a kernel client that listens to kernel
    channels and routes raw (serialized) kernel messages to
    registered listener callbacks.
    """

    @abstractmethod
    async def start_listening(self):
        """Start listening to messages coming from the kernel."""
        ...

    @abstractmethod
    async def stop_listening(self):
        """Stop listening to messages coming from the kernel."""
        ...

    @abstractmethod
    def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
        """Handle a serialized message headed to the kernel.

        Parameters
        ----------
        channel_name : str
            Name of the channel (e.g. "shell", "control") the message
            should be sent on.
        msg : list[bytes]
            The serialized message frames.
        """
        ...

    @abstractmethod
    async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
        """Handle a serialized message coming from the kernel.

        Parameters
        ----------
        channel_name : str
            Name of the channel (e.g. "iopub", "shell") the message
            arrived on.
        msg : list[bytes]
            The serialized message frames.
        """
        ...

    @abstractmethod
    def add_listener(self, callback: t.Callable[[str, list[bytes]], None]):
        """Register a callback to be invoked with (channel_name, msg frames)
        when a kernel message is received."""
        ...

    @abstractmethod
    def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
        """De-register a previously added listener callback."""
        ...


class AbstractDocumentAwareKernelClient(AbstractKernelClient):
    """
    A kernel client interface that, in addition to the base listener
    API, routes kernel messages to registered YRooms (collaborative
    documents).
    """

    @abstractmethod
    async def add_yroom(self, yroom: YRoom):
        """Register a YRoom to handle kernel client messages."""
        ...

    @abstractmethod
    async def remove_yroom(self, yroom: YRoom):
        """De-register a YRoom from handling kernel client messages."""
        ...
27 changes: 13 additions & 14 deletions jupyter_server_documents/kernels/kernel_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ async def connect(self):
self.set_state(LifecycleStates.UNKNOWN, ExecutionStates.UNKNOWN)
raise Exception("The kernel took too long to connect to the ZMQ sockets.")
# Wait a second until the next time we try again.
await asyncio.sleep(1)
await asyncio.sleep(0.5)
# Wait for the kernel to reach an idle state.
while self.execution_state != ExecutionStates.IDLE.value:
self.main_client.send_kernel_info()
await asyncio.sleep(0.5)
await asyncio.sleep(0.1)

async def disconnect(self):
await self.main_client.stop_listening()
Expand All @@ -139,7 +139,11 @@ async def broadcast_state(self):
session = self.main_client.session
parent_header = session.msg_header("status")
parent_msg_id = parent_header["msg_id"]
self.main_client.message_source_cache[parent_msg_id] = "shell"
self.main_client.message_cache.add({
"msg_id": parent_msg_id,
"channel": "shell",
"cellId": None
})
msg = session.msg("status", content={"execution_state": self.execution_state}, parent=parent_header)
smsg = session.serialize(msg)[1:]
await self.main_client.handle_outgoing_message("iopub", smsg)
Expand All @@ -150,7 +154,7 @@ def execution_state_listener(self, channel_name: str, msg: list[bytes]):
if channel_name != "iopub":
return

session = self.main_client.session
session = self.main_client.session
# Unpack the message
deserialized_msg = session.deserialize(msg, content=False)
if deserialized_msg["msg_type"] == "status":
Expand All @@ -162,14 +166,9 @@ def execution_state_listener(self, channel_name: str, msg: list[bytes]):
else:
parent = deserialized_msg.get("parent_header", {})
msg_id = parent.get("msg_id", "")
parent_channel = self.main_client.message_source_cache.get(msg_id, None)
message_data = self.main_client.message_cache.get(msg_id)
if message_data is None:
return
parent_channel = message_data.get("channel")
if parent_channel and parent_channel == "shell":
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))

kernel_status = {
"execution_state": self.execution_state,
"lifecycle_state": self.lifecycle_state
}
self.log.debug(f"Sending kernel status awareness {kernel_status}")
self.main_client.send_kernel_awareness(kernel_status)
self.log.debug(f"Sent kernel status awareness {kernel_status}")
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))
Loading
Loading