diff --git a/jupyter_server_documents/kernels/kernel_client.py b/jupyter_server_documents/kernels/kernel_client.py
index 2cc0b9b..57dc94b 100644
--- a/jupyter_server_documents/kernels/kernel_client.py
+++ b/jupyter_server_documents/kernels/kernel_client.py
@@ -9,15 +9,17 @@ from traitlets import Set, Instance, Any, Type, default
 
 from jupyter_client.asynchronous.client import AsyncKernelClient
-from .utils import LRUCache
+from .message_cache import KernelMessageCache
 from jupyter_server_documents.rooms.yroom import YRoom
 from jupyter_server_documents.outputs import OutputProcessor
 from jupyter_server.utils import ensure_async
+from .kernel_client_abc import AbstractDocumentAwareKernelClient
 
-class DocumentAwareKernelClient(AsyncKernelClient):
+
+class DocumentAwareKernelClient(AsyncKernelClient):
     """
-    A kernel client
+    A kernel client that routes messages to registered ydocs.
     """
     # Having this message cache is not ideal.
     # Unfortunately, we don't include the parent channel
@@ -25,10 +27,14 @@ class DocumentAwareKernelClient(AsyncKernelClient):
     # we can't differentiate between the control channel vs.
     # shell channel status. This message cache gives us
     # the ability to map status messages back to their source.
-    message_source_cache = Instance(
-        default_value=LRUCache(maxsize=1000), klass=LRUCache
+    message_cache = Instance(
+        klass=KernelMessageCache
     )
 
+    @default('message_cache')
+    def _default_message_cache(self):
+        return KernelMessageCache(parent=self)
+
     # A set of callables that are called when a kernel
     # message is received.
     _listeners = Set(allow_none=True)
@@ -37,6 +43,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
     # status messages.
     _yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set())
 
+
     output_processor = Instance(
         OutputProcessor,
         allow_none=True
     )
 
@@ -50,7 +57,7 @@ class DocumentAwareKernelClient(AsyncKernelClient):
     @default("output_processor")
     def _default_output_processor(self) -> OutputProcessor:
         self.log.info("Creating output processor")
-        return OutputProcessor(parent=self, config=self.config)
+        return self.output_process_class(parent=self, config=self.config)
 
     async def start_listening(self):
         """Start listening to messages coming from the kernel.
@@ -94,10 +101,23 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
         # Cache the message ID and its socket name so that
         # any response message can be mapped back to the
         # source channel.
-        self.output_processor.process_incoming_message(channel=channel_name, msg=msg)
-        header = json.loads(msg[0])  # TODO: use session.unpack
-        msg_id = header["msg_id"]
-        self.message_source_cache[msg_id] = channel_name
+        header = self.session.unpack(msg[0])
+        msg_id = header["msg_id"]
+        metadata = self.session.unpack(msg[2])
+        cell_id = metadata.get("cellId")
+
+        # Clear the output processor's state if this cell already
+        # has an existing request.
+        if cell_id:
+            existing = self.message_cache.get(cell_id=cell_id)
+            if existing and existing['msg_id'] != msg_id:
+                self.output_processor.clear(cell_id)
+
+        self.message_cache.add({
+            "msg_id": msg_id,
+            "channel": channel_name,
+            "cell_id": cell_id
+        })
 
         channel = getattr(self, f"{channel_name}_channel")
         channel.session.send_raw(channel.socket, msg)
@@ -152,7 +172,7 @@ async def send_message_to_listeners(self, channel_name: str, msg: list[bytes]):
         async with anyio.create_task_group() as tg:
             # Broadcast the message to all listeners.
for listener in self._listeners: - async def _wrap_listener(listener_to_wrap, channel_name, msg): + async def _wrap_listener(listener_to_wrap, channel_name, msg): """ Wrap the listener to ensure its async and logs (instead of raises) exceptions. @@ -172,63 +192,99 @@ async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]): when appropriate. Then, it routes the message to all listeners. """ - # Intercept messages that are IOPub focused. - if channel_name == "iopub": - message_returned = await self.handle_iopub_message(msg) - # If the message is not returned by the iopub handler, then - # return here and do not forward to listeners. - if not message_returned: - self.log.warn(f"If message is handled do not forward after adding output manager") + if channel_name in ('iopub', 'shell'): + msg = await self.handle_document_related_message(msg) + # If msg has been cleared by the handler, escape this method. + if msg is None: return - - # Update the last activity. - # self.last_activity = self.session.msg_time + await self.send_message_to_listeners(channel_name, msg) - async def handle_iopub_message(self, msg: list[bytes]) -> t.Optional[list[bytes]]: + async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optional[t.List[bytes]]: """ - Handle messages + Processes document-related messages received from a Jupyter kernel. - Parameters - ---------- - dmsg: dict - Deserialized message (except concept) - - Returns - ------- - Returns the message if it should be forwarded to listeners. Otherwise, - returns `None` and prevents (i.e. intercepts) the message from going - to listeners. - """ + Messages are deserialized and handled based on their type. Supported message types + include updating language info, kernel status, execution state, execution count, + and various output types. Some messages may be processed by an output processor + before deciding whether to forward them. + Returns the original message if it is not processed further, otherwise None to indicate + that the message should not be forwarded. + """ + # Begin to deserialize the message safely within a try-except block try: dmsg = self.session.deserialize(msg, content=False) except Exception as e: self.log.error(f"Error deserializing message: {e}") raise - if self.output_processor is not None and dmsg["msg_type"] in ("stream", "display_data", "execute_result", "error"): - dmsg = self.output_processor.process_outgoing_message(dmsg) - - # If process_outgoing_message returns None, return None so the message isn't - # sent to clients, otherwise return the original serialized message. - if dmsg is None: - return None - else: - return msg - - def send_kernel_awareness(self, kernel_status: dict): - """ - Send kernel status awareness messages to all yrooms - """ - for yroom in self._yrooms: - awareness = yroom.get_awareness() - if awareness is None: - self.log.error(f"awareness cannot be None. room_id: {yroom.room_id}") - continue - self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}. 
kernel status: {kernel_status}") - awareness.set_local_state_field("kernel", kernel_status) - self.log.debug(f"current state: {awareness.get_local_state()} room_id: {yroom.room_id}") + parent_msg_id = dmsg["parent_header"]["msg_id"] + parent_msg_data = self.message_cache.get(parent_msg_id) + cell_id = parent_msg_data.get('cell_id') + + # Handle different message types using pattern matching + match dmsg["msg_type"]: + case "kernel_info_reply": + # Unpack the content to extract language info + content = self.session.unpack(dmsg["content"]) + language_info = content["language_info"] + # Update the language info metadata for each collaborative room + for yroom in self._yrooms: + notebook = await yroom.get_jupyter_ydoc() + # The metadata ydoc is not exposed as a + # public property. + metadata = notebook.ymeta + metadata["metadata"]["language_info"] = language_info + + case "status": + # Unpack cell-specific information and determine execution state + content = self.session.unpack(dmsg["content"]) + execution_state = content.get("execution_state") + # Update status across all collaborative rooms + for yroom in self._yrooms: + # If this status came from the shell channel, update + # the notebook status. + if parent_msg_data["channel"] == "shell": + awareness = yroom.get_awareness() + if awareness is not None: + # Update the kernel execution state at the top document level + awareness.set_local_state_field("kernel", {"execution_state": execution_state}) + # Specifically update the running cell's execution state if cell_id is provided + if cell_id: + notebook = await yroom.get_jupyter_ydoc() + _, target_cell = notebook.find_cell(cell_id) + if target_cell: + # Adjust state naming convention from 'busy' to 'running' as per JupyterLab expectation + # https://github.com/jupyterlab/jupyterlab/blob/0ad84d93be9cb1318d749ffda27fbcd013304d50/packages/cells/src/widget.ts#L1670-L1678 + state = 'running' if execution_state == 'busy' else execution_state + target_cell["execution_state"] = state + break + + case "execute_input": + if cell_id: + # Extract execution count and update each collaborative room's notebook + content = self.session.unpack(dmsg["content"]) + execution_count = content["execution_count"] + for yroom in self._yrooms: + notebook = await yroom.get_jupyter_ydoc() + _, target_cell = notebook.find_cell(cell_id) + if target_cell: + target_cell["execution_count"] = execution_count + break + + case "stream" | "display_data" | "execute_result" | "error": + if cell_id: + # Process specific output messages through an optional processor + if self.output_processor and cell_id: + cell_id = parent_msg_data.get('cell_id') + content = self.session.unpack(dmsg["content"]) + dmsg = self.output_processor.process_output(dmsg['msg_type'], cell_id, content) + # Suppress forwarding of processed messages by returning None + return None + + # Default return if message is processed and does not need forwarding + return msg async def add_yroom(self, yroom: YRoom): """ @@ -242,3 +298,6 @@ async def remove_yroom(self, yroom: YRoom): De-register a YRoom from handling kernel client messages. 
""" self._yrooms.discard(yroom) + + +AbstractDocumentAwareKernelClient.register(DocumentAwareKernelClient) diff --git a/jupyter_server_documents/kernels/kernel_client_abc.py b/jupyter_server_documents/kernels/kernel_client_abc.py new file mode 100644 index 0000000..ecb705c --- /dev/null +++ b/jupyter_server_documents/kernels/kernel_client_abc.py @@ -0,0 +1,42 @@ +import typing as t +from abc import ABC, abstractmethod + +from jupyter_server_documents.rooms.yroom import YRoom + + +class AbstractKernelClient(ABC): + + @abstractmethod + async def start_listening(self): + ... + + @abstractmethod + async def stop_listening(self): + ... + + @abstractmethod + def handle_incoming_message(self, channel_name: str, msg: list[bytes]): + ... + + @abstractmethod + async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]): + ... + + @abstractmethod + def add_listener(self, callback: t.Callable[[str, list[bytes]], None]): + ... + + @abstractmethod + def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]): + ... + + +class AbstractDocumentAwareKernelClient(AbstractKernelClient): + + @abstractmethod + async def add_yroom(self, yroom: YRoom): + ... + + @abstractmethod + async def remove_yroom(self, yroom: YRoom): + ... \ No newline at end of file diff --git a/jupyter_server_documents/kernels/kernel_manager.py b/jupyter_server_documents/kernels/kernel_manager.py index 8200ba3..00083da 100644 --- a/jupyter_server_documents/kernels/kernel_manager.py +++ b/jupyter_server_documents/kernels/kernel_manager.py @@ -120,11 +120,11 @@ async def connect(self): self.set_state(LifecycleStates.UNKNOWN, ExecutionStates.UNKNOWN) raise Exception("The kernel took too long to connect to the ZMQ sockets.") # Wait a second until the next time we try again. - await asyncio.sleep(1) + await asyncio.sleep(0.5) # Wait for the kernel to reach an idle state. 
while self.execution_state != ExecutionStates.IDLE.value:
             self.main_client.send_kernel_info()
-            await asyncio.sleep(0.5)
+            await asyncio.sleep(0.1)
 
     async def disconnect(self):
         await self.main_client.stop_listening()
@@ -139,7 +139,11 @@ async def broadcast_state(self):
         session = self.main_client.session
         parent_header = session.msg_header("status")
         parent_msg_id = parent_header["msg_id"]
-        self.main_client.message_source_cache[parent_msg_id] = "shell"
+        self.main_client.message_cache.add({
+            "msg_id": parent_msg_id,
+            "channel": "shell",
+            "cell_id": None
+        })
         msg = session.msg("status", content={"execution_state": self.execution_state}, parent=parent_header)
         smsg = session.serialize(msg)[1:]
         await self.main_client.handle_outgoing_message("iopub", smsg)
@@ -150,7 +154,7 @@ def execution_state_listener(self, channel_name: str, msg: list[bytes]):
         if channel_name != "iopub":
             return
 
-        session = self.main_client.session 
+        session = self.main_client.session
         # Unpack the message
         deserialized_msg = session.deserialize(msg, content=False)
         if deserialized_msg["msg_type"] == "status":
@@ -162,14 +166,9 @@ def execution_state_listener(self, channel_name: str, msg: list[bytes]):
         else:
             parent = deserialized_msg.get("parent_header", {})
             msg_id = parent.get("msg_id", "")
-            parent_channel = self.main_client.message_source_cache.get(msg_id, None)
+            message_data = self.main_client.message_cache.get(msg_id)
+            if message_data is None:
+                return
+            parent_channel = message_data.get("channel")
             if parent_channel and parent_channel == "shell":
-                self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))
-
-                kernel_status = {
-                    "execution_state": self.execution_state,
-                    "lifecycle_state": self.lifecycle_state
-                }
-                self.log.debug(f"Sending kernel status awareness {kernel_status}")
-                self.main_client.send_kernel_awareness(kernel_status)
-                self.log.debug(f"Sent kernel status awareness {kernel_status}")
\ No newline at end of file
+                self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))
\ No newline at end of file
diff --git a/jupyter_server_documents/kernels/message_cache.py b/jupyter_server_documents/kernels/message_cache.py
new file mode 100644
index 0000000..b31ba0b
--- /dev/null
+++ b/jupyter_server_documents/kernels/message_cache.py
@@ -0,0 +1,226 @@
+import json
+from collections import OrderedDict
+from traitlets import Dict, Instance, Int
+from traitlets.config import LoggingConfigurable
+
+
+class MissingKeyException(Exception):
+    """Raised when a message dictionary is missing a required key."""
+
+class InvalidKeyException(Exception):
+    """Raised when the key doesn't match the `msg_id` property in the value."""
+
+
+class KernelMessageCache(LoggingConfigurable):
+    """
+    A cache for storing kernel messages, optimized for access by message ID and cell ID.
+
+    The cache uses an OrderedDict for message IDs to maintain insertion order and
+    implement LRU eviction. Messages are also indexed by cell ID for faster
+    retrieval when the cell ID is known.
+
+    Attributes:
+        _by_cell_id (dict): A dictionary mapping cell IDs to message data.
+        _by_msg_id (OrderedDict): An OrderedDict mapping message IDs to message data,
+            maintaining insertion order for LRU eviction.
+        maxsize (int): The maximum number of messages to store in the cache.
+    """
+
+    _by_cell_id = Dict({})
+    # Use args=() so that each instance constructs its own OrderedDict;
+    # a default_value=OrderedDict() would be shared across all instances.
+    _by_msg_id = Instance(OrderedDict, args=())
+    maxsize = Int(default_value=10000).tag(config=True)
+
+    def __repr__(self):
+        """
+        Returns a JSON string representation of the message ID cache.
+ """ + return json.dumps(self._by_msg_id, indent=2) + + def __getitem__(self, msg_id): + """ + Retrieves a message from the cache by message ID. Moves the accessed + message to the end of the OrderedDict to update its access time. + + Args: + msg_id (str): The message ID. + + Returns: + dict: The message data. + + Raises: + KeyError: If the message ID is not found in the cache. + """ + out = self._by_msg_id[msg_id] + self._by_msg_id.move_to_end(msg_id) + return out + + def __setitem__(self, msg_id, value): + """ + Adds a message to the cache. If the cache is full, the least recently + used message is evicted. + + Args: + msg_id (str): The message ID. + value (dict): The message data. + + Raises: + Exception: If the msg_id does not match the message ID in the value, + or if the message data is missing required fields + ("msg_id", "channel"). + """ + if "msg_id" not in value: + raise MissingKeyException("`msg_id` missing in message data") + + if "channel" not in value: + raise MissingKeyException("`channel` missing in message data") + + if value["msg_id"] != msg_id: + raise InvalidKeyException("Key must match `msg_id` in value") + + # Remove the existing msg_id if a new msg with same cell_id exists + if value["channel"] == "shell" and "cell_id" in value and value["cell_id"] in self._by_cell_id: + existing_msg_id = self._by_cell_id[value["cell_id"]]["msg_id"] + if msg_id != existing_msg_id: + del self._by_msg_id[existing_msg_id] + + if "cell_id" in value and value['cell_id'] is not None: + self._by_cell_id[value['cell_id']] = value + + self._by_msg_id[msg_id] = value + if len(self._by_msg_id) > self.maxsize: + self._remove_oldest() + + def _remove_oldest(self): + """ + Removes the least recently used message from the cache. + """ + try: + key, item = self._by_msg_id.popitem(last=False) + if 'cell_id' in item: # Check if 'cell_id' key exists + try: + del self._by_cell_id[item['cell_id']] + except KeyError: + pass # Handle the case where the cell_id is not present + except KeyError: + pass # Handle the case where the cache is empty + + def __delitem__(self, msg_id): + """ + Removes a message from the cache by message ID. + + Args: + msg_id (str): The message ID. + """ + msg_data = self._by_msg_id[msg_id] + try: + cell_id = msg_data["cell_id"] + del self._by_cell_id[cell_id] + except KeyError: + pass + del self._by_msg_id[msg_id] + + def __contains__(self, msg_id): + """ + Checks if a message with the given message ID is in the cache. + + Args: + msg_id (str): The message ID. + + Returns: + bool: True if the message is in the cache, False otherwise. + """ + return msg_id in self._by_msg_id + + def __iter__(self): + """ + Returns an iterator over the message IDs in the cache. + """ + for msg_id in self._by_msg_id: + yield msg_id + + def __len__(self): + """ + Returns the number of messages in the cache. + """ + return len(self._by_msg_id) + + def add(self, data): + """ + Adds a message to the cache using its message ID as the key. + + Args: + data (dict): The message data. + """ + self[data['msg_id']] = data + + def get(self, msg_id=None, cell_id=None): + """ + Retrieves a message from the cache, either by message ID or cell ID. + + Args: + msg_id (str, optional): The message ID. Defaults to None. + cell_id (str, optional): The cell ID. Defaults to None. + + Returns: + dict: The message data, or None if not found. 
+ """ + try: + out = self._by_cell_id[cell_id] + msg_id = out['msg_id'] + self._by_msg_id.move_to_end(msg_id) + return out + except KeyError: + try: + out = self._by_msg_id[msg_id] + self._by_msg_id.move_to_end(msg_id) + return out + except KeyError: + return None + + def remove(self, msg_id=None, cell_id=None): + """ + Removes a message from the cache, either by message ID or cell ID. + + Args: + msg_id (str, optional): The message ID. Defaults to None. + cell_id (str, optional): The cell ID. Defaults to None. + """ + try: + out = self._by_cell_id[cell_id] + msg_id = out['msg_id'] + del self._by_msg_id[msg_id] + del self._by_cell_id[cell_id] + except KeyError: + try: + out = self._by_msg_id[msg_id] + try: + cell_id = out['cell_id'] + del self._by_cell_id[cell_id] + except KeyError: + pass + finally: + del self._by_msg_id[msg_id] + except KeyError: + return + + def pop(self, msg_id=None, cell_id=None): + """ + Removes and returns a message from the cache, either by message ID or cell ID. + + Args: + msg_id (str, optional): The message ID. Defaults to None. + cell_id (str, optional): The cell ID. Defaults to None. + + Returns: + dict: The message data. + + Raises: + KeyError: If the message ID or cell ID is not found. + """ + try: + out = self._by_cell_id[cell_id] + except KeyError: + out = self._by_msg_id[msg_id] + self.remove(msg_id=out['msg_id']) + return out \ No newline at end of file diff --git a/jupyter_server_documents/kernels/utils.py b/jupyter_server_documents/kernels/utils.py deleted file mode 100644 index b9c3e62..0000000 --- a/jupyter_server_documents/kernels/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from collections import OrderedDict - - -class LRUCache(OrderedDict): - """Limit size, evicting the least recently looked-up key when full""" - - def __init__(self, maxsize=128, *args, **kwds): - self.maxsize = maxsize - super().__init__(*args, **kwds) - - def __getitem__(self, key): - value = super().__getitem__(key) - self.move_to_end(key) - return value - - def __setitem__(self, key, value): - super().__setitem__(key, value) - if len(self) > self.maxsize: - oldest = next(iter(self)) - del self[oldest] \ No newline at end of file diff --git a/jupyter_server_documents/outputs/output_processor.py b/jupyter_server_documents/outputs/output_processor.py index 4f762e5..94f11cb 100644 --- a/jupyter_server_documents/outputs/output_processor.py +++ b/jupyter_server_documents/outputs/output_processor.py @@ -3,14 +3,12 @@ from pycrdt import Map -from traitlets import Dict, Unicode, Bool +from traitlets import Dict, Unicode, Bool, Instance from traitlets.config import LoggingConfigurable - +from jupyter_server_documents.kernels.message_cache import KernelMessageCache class OutputProcessor(LoggingConfigurable): - - _cell_ids = Dict(default_value={}) # a map from msg_id -> cell_id - _cell_indices = Dict(default_value={}) # a map from cell_id -> cell index in notebook + _file_id = Unicode(default_value=None, allow_none=True) use_outputs_service = Bool( @@ -49,32 +47,6 @@ def yroom_manager(self): """A shortcut for the jupyter server ydoc manager.""" return self.settings["yroom_manager"] - def clear(self, cell_id=None): - """Clear the state of the output processor. - - This clears the state (saved msg_ids, cell_ids, cell indices) for the output - processor. If cell_id is provided, only the state for that cell is cleared. 
-        """
-        if cell_id is None:
-            self._cell_ids = {}
-            self._cell_indices = {}
-        else:
-            msg_id = self.get_msg_id(cell_id)
-            if (msg_id is not None) and (msg_id in self._cell_ids): del self._cell_ids[msg_id]
-            if cell_id in self._cell_indices: del self._cell_indices[cell_id]
-
-    def set_cell_id(self, msg_id, cell_id):
-        """Set the cell_id for a msg_id."""
-        self._cell_ids[msg_id] = cell_id
-
-    def get_cell_id(self, msg_id):
-        """Retrieve a cell_id from a parent msg_id."""
-        return self._cell_ids.get(msg_id)
-
-    def get_msg_id(self, cell_id):
-        """Retrieve a msg_id from a cell_id."""
-        return {v: k for k, v in self._cell_ids.items()}.get(cell_id)
-
     async def clear_cell_outputs(self, cell_id, room_id):
         """Clears the outputs of a cell in ydoc"""
         room = self.yroom_manager.get_room(room_id)
@@ -84,53 +56,21 @@ async def clear_cell_outputs(self, cell_id, room_id):
         notebook = await room.get_jupyter_ydoc()
         self.log.info(f"Notebook: {notebook}")
 
-        cells = notebook.ycells
-        cell_index, target_cell = self.find_cell(cell_id, cells)
+        cell_index, target_cell = notebook.find_cell(cell_id)
         if target_cell is not None:
             target_cell["outputs"].clear()
             self.log.info(f"Cleared outputs for ydoc: {room_id} {cell_index}")
-
-    # Incoming messages
-
-    def process_incoming_message(self, channel: str, msg: list[bytes]):
-        """Process incoming messages from the frontend.
-
-        Save the cell_id <-> msg_id mapping
-
-        msg = [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
-
-        This method is used to create a map between cell_id and msg_id.
-        Incoming execute_request messages have both a cell_id and msg_id.
-        When output messages are send back to the frontend, this map is used
-        to find the cell_id for a given parent msg_id.
-        """
-        if channel != "shell":
-            return
-        header = json.loads(msg[0])  # TODO use session unpack
-        msg_type = header.get("msg_type")
-        if msg_type != "execute_request":
-            return
-        msg_id = header.get("msg_id")
-        metadata = json.loads(msg[2])  # TODO use session unpack
-        cell_id = metadata.get("cellId")
-        if cell_id is None:
-            return  # cellId is optional, so this is valid
-
-        existing_msg_id = self.get_msg_id(cell_id)
-        if existing_msg_id != msg_id:  # cell is being re-run, clear output state
-            self.clear(cell_id)
-            if self._file_id is not None:
-                if self.use_outputs_service:
-                    room_id = f"json:notebook:{self._file_id}"
-                    asyncio.create_task(self.clear_cell_outputs(cell_id, room_id))
-                self.outputs_manager.clear(file_id=self._file_id, cell_id=cell_id)
-        self.log.info(f"Saving (msg_id, cell_id): ({msg_id} {cell_id})")
-        self.set_cell_id(msg_id, cell_id)
+
+    def clear(self, cell_id):
+        """Clear all outputs for a given cell id."""
+        if self._file_id is not None:
+            if self.use_outputs_service:
+                room_id = f"json:notebook:{self._file_id}"
+                asyncio.create_task(self.clear_cell_outputs(cell_id, room_id))
+            self.outputs_manager.clear(file_id=self._file_id, cell_id=cell_id)
 
     # Outgoing messages
-
-    def process_outgoing_message(self, dmsg: dict):
-        """Process outgoing messages from the kernel.
-
-        This returns the input dmsg if no the message should be sent to
+
+    def process_output(self, msg_type: str, cell_id: str, content: dict):
+        """Process an output message from the kernel.
+
+        This schedules an async task that writes the output to the ydoc.
@@ -143,15 +83,6 @@ def process_outgoing_message(self, dmsg: dict):
-        The content has not been deserialized yet as we need to verify
-        we should process it.
+        The content has already been unpacked by the caller. This method
+        always returns None so the raw message is not forwarded to clients.
""" - msg_type = dmsg["header"]["msg_type"] - if msg_type not in ("stream", "display_data", "execute_result", "error"): - return dmsg - msg_id = dmsg["parent_header"]["msg_id"] - content = self.parent.session.unpack(dmsg["content"]) - cell_id = self.get_cell_id(msg_id) - if cell_id is None: - # This is valid as cell_id is optional - return dmsg asyncio.create_task(self.output_task(msg_type, cell_id, content)) return None # Don't allow the original message to propagate to the frontend @@ -192,54 +123,11 @@ async def output_task(self, msg_type, cell_id, content): self.log.info(f"Notebook: {notebook}") # Write the outputs to the ydoc cell. - cells = notebook.ycells - cell_index, target_cell = self.find_cell(cell_id, cells) + _, target_cell = notebook.find_cell(cell_id) if target_cell is not None and output is not None: target_cell["outputs"].append(output) self.log.info(f"Write output to ydoc: {path} {cell_id} {output}") - def find_cell(self, cell_id, cells): - """Find a cell with a given cell_id in the list of cells. - - This uses caching if we have seen the cell previously. - """ - # Find the target_cell and its cell_index and cache - target_cell = None - cell_index = None - try: - # See if we have a cached value for the cell_index - cell_index = self._cell_indices[cell_id] - target_cell = cells[cell_index] - except KeyError: - # Do a linear scan to find the cell - self.log.info(f"Linear scan: {cell_id}") - cell_index, target_cell = self.scan_cells(cell_id, cells) - else: - # Verify that the cached value still matches - if target_cell["id"] != cell_id: - self.log.info(f"Invalid cache hit: {cell_id}") - cell_index, target_cell = self.scan_cells(cell_id, cells) - else: - self.log.info(f"Validated cache hit: {cell_id}") - return cell_index, target_cell - - def scan_cells(self, cell_id, cells): - """Find the cell with a given cell_id in the list of cells. - - This does a simple linear scan of the cells, but in reverse order because - we believe that users are more often running cells towards the end of the - notebook. 
- """ - target_cell = None - cell_index = None - for i in reversed(range(0, len(cells))): - cell = cells[i] - if cell["id"] == cell_id: - target_cell = cell - cell_index = i - self._cell_indices[cell_id] = cell_index - break - return cell_index, target_cell def transform_output(self, msg_type, content, ydoc=False): """Transform output from IOPub messages to the nbformat specification.""" diff --git a/jupyter_server_documents/rooms/yroom.py b/jupyter_server_documents/rooms/yroom.py index 4c9d93d..6672282 100644 --- a/jupyter_server_documents/rooms/yroom.py +++ b/jupyter_server_documents/rooms/yroom.py @@ -6,7 +6,7 @@ import pycrdt from pycrdt import YMessageType, YSyncMessageType as YSyncMessageSubtype -from jupyter_ydoc import ydocs as jupyter_ydoc_classes +from jupyter_server_documents.ydocs import ydocs as jupyter_ydoc_classes from jupyter_ydoc.ybasedoc import YBaseDoc from tornado.websocket import WebSocketHandler from .yroom_file_api import YRoomFileAPI diff --git a/jupyter_server_documents/tests/test_kernel_message_cache.py b/jupyter_server_documents/tests/test_kernel_message_cache.py new file mode 100644 index 0000000..f89aadc --- /dev/null +++ b/jupyter_server_documents/tests/test_kernel_message_cache.py @@ -0,0 +1,285 @@ +import pytest +from collections import OrderedDict +from jupyter_server_documents.kernels.message_cache import InvalidKeyException, KernelMessageCache, MissingKeyException # Replace your_module + + +def create_cache(maxsize=None): + if maxsize: + cache = KernelMessageCache(maxsize=maxsize) + else: + cache = KernelMessageCache() + + # somehow the same cache is shared in the tests + # clearing the cache so tests pass + cache._by_msg_id.clear() + return cache + + +def create_message(msg_id, channel, cell_id=None, content="test content"): + message = { + "msg_id": msg_id, + "channel": channel, + "content": content, + } + if cell_id: + message["cell_id"] = cell_id + return message + + +def test_setitem_and_getitem(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + assert cache["msg1"] == message1 + + +def test_setitem_key_mismatch(): + cache = create_cache() + message1 = create_message("msg1", "shell") + with pytest.raises(InvalidKeyException, match="Key must match `msg_id` in value"): + cache["wrong_key"] = message1 + + +def test_setitem_missing_msg_id(): + cache = create_cache() + message1 = {"channel": "shell"} # Missing msg_id + with pytest.raises(MissingKeyException, match="`msg_id` missing in message data"): + cache["key"] = message1 + + +def test_setitem_missing_channel(): + cache = create_cache() + message1 = {"msg_id": "msg1"} # Missing channel + with pytest.raises(MissingKeyException, match="`channel` missing in message data"): + cache["msg1"] = message1 + + +def test_setitem_with_cell_id(): + cache = create_cache() + message1 = create_message("msg1", "shell", "cell1") + cache["msg1"] = message1 + assert cache._by_cell_id["cell1"] == message1 + assert "msg1" in cache._by_msg_id + + +def test_setitem_without_cell_id(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + assert "msg1" in cache._by_msg_id + assert not cache._by_cell_id + + +def test_delitem(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + del cache["msg1"] + assert "msg1" not in cache + assert "msg1" not in cache._by_msg_id + + +def test_delitem_with_cell_id(): + cache = create_cache() + message1 = create_message("msg1", "shell", "cell1") + 
cache["msg1"] = message1 + del cache["msg1"] + assert "msg1" not in cache + assert "cell1" not in cache._by_cell_id + assert "msg1" not in cache._by_msg_id + + +def test_contains(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + assert "msg1" in cache + assert "nonexistent_key" not in cache + + +def test_iter(): + cache = create_cache() + message1 = create_message("msg1", "shell") + message2 = create_message("msg2", "shell") + cache["msg1"] = message1 + cache["msg2"] = message2 + keys = list(cache) + assert "msg1" in keys + assert "msg2" in keys + assert len(keys) == 2 + + +def test_len(): + cache = create_cache() + message1 = create_message("msg1", "shell") + message2 = create_message("msg2", "shell") + cache["msg1"] = message1 + cache["msg2"] = message2 + assert len(cache) == 2 + + +def test_add(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache.add(message1) + assert cache["msg1"] == message1 + assert "msg1" in cache._by_msg_id + + +def test_get_by_msg_id(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + retrieved_message = cache.get(msg_id="msg1") + assert retrieved_message == message1 + assert isinstance(cache._by_msg_id, OrderedDict) # Check it's still OrderedDict + + +def test_get_by_cell_id(): + cache = create_cache() + message1 = create_message("msg1", "shell", "cell1") + cache["msg1"] = message1 + retrieved_message = cache.get(cell_id="cell1") + assert retrieved_message == message1 + assert isinstance(cache._by_msg_id, OrderedDict) # Check it's still OrderedDict + + +def test_get_not_found(): + cache = create_cache() + retrieved_message = cache.get(msg_id="nonexistent_key") + assert retrieved_message is None + + +def test_remove_by_msg_id(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + cache.remove(msg_id="msg1") + assert "msg1" not in cache + assert "msg1" not in cache._by_msg_id + + +def test_remove_by_cell_id(): + cache = create_cache() + message1 = create_message("msg1", "shell", "cell1") + cache["msg1"] = message1 + cache.remove(cell_id="cell1") + assert "msg1" not in cache + assert "cell1" not in cache._by_cell_id + assert "msg1" not in cache._by_msg_id + + +def test_remove_nonexistent(): + cache = create_cache() + cache.remove(msg_id="nonexistent_key") # Should not raise an error + cache.remove(cell_id="nonexistent_cell") # Should not raise an error + + +def test_pop_by_msg_id(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + popped_message = cache.pop(msg_id="msg1") + assert popped_message == message1 + assert "msg1" not in cache + assert "msg1" not in cache._by_msg_id + + +def test_pop_by_cell_id(): + cache = create_cache() + message1 = create_message("msg1", "shell", "cell1") + cache["msg1"] = message1 + popped_message = cache.pop(cell_id="cell1") + assert popped_message == message1 + assert "msg1" not in cache + assert "cell1" not in cache._by_cell_id + assert "msg1" not in cache._by_msg_id + + +def test_pop_nonexistent(): + cache = create_cache() + with pytest.raises(KeyError): + cache.pop(msg_id="nonexistent_key") + + +def test_repr(): + cache = create_cache() + message1 = create_message("msg1", "shell") + cache["msg1"] = message1 + representation = repr(cache) + assert '"msg1":' in representation + assert '"channel": "shell"' in representation + + +def test_lru_behavior(): + cache = create_cache() + cache._by_msg_id = 
OrderedDict() # Reset to OrderedDict for LRU test + cache._by_msg_id["msg1"] = create_message("msg1", "shell") + cache._by_msg_id["msg2"] = create_message("msg2", "shell") + cache._by_msg_id["msg3"] = create_message("msg3", "shell") + + # Access "msg1" to make it the most recently used + cache["msg1"] + + # Check the order after accessing + expected_order = ["msg2", "msg3", "msg1"] + assert list(cache._by_msg_id.keys()) == expected_order + + +def test_maxsize_eviction(): + cache = create_cache(2) + message1 = create_message("msg1", "shell") + message2 = create_message("msg2", "shell") + message3 = create_message("msg3", "shell") + + cache["msg1"] = message1 + cache["msg2"] = message2 + cache["msg3"] = message3 # This should evict "msg1" + + assert "msg1" not in cache + assert "msg2" in cache + assert "msg3" in cache + assert len(cache) == 2 + + +def test_remove_oldest(): + cache = create_cache() + cache = KernelMessageCache(maxsize=2) + message1 = create_message("msg1", "shell") + message2 = create_message("msg2", "shell") + message3 = create_message("msg3", "shell") + cache["msg1"] = message1 + cache["msg2"] = message2 + cache["msg3"] = message3 # should trigger remove_oldest + + assert "msg1" not in cache + assert "msg2" in cache + assert "msg3" in cache + +def test_maxsize_eviction_with_cell_id(): + cache = create_cache(2) + message1 = create_message("msg1", "shell", "cell1") + message2 = create_message("msg2", "shell", "cell2") + message3 = create_message("msg3", "shell", "cell3") + + cache["msg1"] = message1 + cache["msg2"] = message2 + cache["msg3"] = message3 # This should evict "msg1" + + assert "msg1" not in cache + assert "msg2" in cache + assert "msg3" in cache + assert "cell1" not in cache._by_cell_id + assert "cell2" in cache._by_cell_id + assert "cell3" in cache._by_cell_id + +def test_existing_msg_id_is_removed(): + cache = create_cache() + message1 = create_message("msg1", "shell", "cell1") + message2 = create_message("msg2", "shell", "cell1") + + cache["msg1"] = message1 + cache["msg2"] = message2 + + assert "msg1" not in cache + assert message2 == cache["msg2"] \ No newline at end of file diff --git a/jupyter_server_documents/ydocs.py b/jupyter_server_documents/ydocs.py new file mode 100644 index 0000000..ae23f24 --- /dev/null +++ b/jupyter_server_documents/ydocs.py @@ -0,0 +1,82 @@ +""" +Extend the YNotebook class with some useful utilities for searching Notebooks. +""" +import sys +# See compatibility note on `group` keyword in +# https://docs.python.org/3/library/importlib.metadata.html#entry-points +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +from jupyter_ydoc.ynotebook import YNotebook as UpstreamYNotebook + + +class YNotebook(UpstreamYNotebook): + __doc__ = """ + Extends upstream YNotebook to include extra methods. + + Upstream docstring: + """ + UpstreamYNotebook.__doc__ + + _cell_indices: dict # a map from cell_id -> cell index in notebook + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cell_indices = {} + + @property + def ymeta(self): + """ + Returns the Y-meta. + + :return: The Y-meta. + :rtype: :class:`pycrdt.Map` + """ + return self._ymeta + + def find_cell(self, cell_id): + """Find a cell with a given cell_id in the list of cells. + + This uses caching if we have seen the cell previously. 
+ """ + # Find the target_cell and its cell_index and cache + target_cell = None + cell_index = None + try: + # See if we have a cached value for the cell_index + cell_index = self._cell_indices[cell_id] + target_cell = self.ycells[cell_index] + except KeyError: + # Do a linear scan to find the cell + cell_index, target_cell = self.scan_cells(cell_id) + else: + # Verify that the cached value still matches + if target_cell["id"] != cell_id: + cell_index, target_cell = self.scan_cells(cell_id) + return cell_index, target_cell + + def scan_cells(self, cell_id): + """Find the cell with a given cell_id in the list of cells. + + This does a simple linear scan of the cells, but in reverse order because + we believe that users are more often running cells towards the end of the + notebook. + """ + target_cell = None + cell_index = None + for i in reversed(range(0, len(self.ycells))): + cell = self.ycells[i] + if cell["id"] == cell_id: + target_cell = cell + cell_index = i + self._cell_indices[cell_id] = cell_index + break + return cell_index, target_cell + + + +ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyter_ydoc")} + +# Replace the YNotebook with our local version +ydocs["notebook"] = YNotebook diff --git a/src/notebook.ts b/src/notebook.ts index c75beda..e2211dd 100644 --- a/src/notebook.ts +++ b/src/notebook.ts @@ -1,11 +1,19 @@ -import { CodeCell, CodeCellModel } from '@jupyterlab/cells'; +import { CodeCell, CodeCellModel, ICellModel, ICodeCellModel } from '@jupyterlab/cells'; import { NotebookPanel } from '@jupyterlab/notebook'; import { CellChange, createMutex, ISharedCodeCell } from '@jupyter/ydoc'; import { IOutputAreaModel, OutputAreaModel } from '@jupyterlab/outputarea'; +import { IChangedArgs } from '@jupyterlab/coreutils'; import { requestAPI } from './handler'; const globalModelDBMutex = createMutex(); +/** + * The class name added to the cell when dirty. + */ +const DIRTY_CLASS = 'jp-mod-dirty'; + + + (CodeCellModel.prototype as any)._onSharedModelChanged = function ( slot: ISharedCodeCell, change: CellChange @@ -134,6 +142,54 @@ class RtcOutputAreaModel extends OutputAreaModel implements IOutputAreaModel { } } +/** + * NOTE: We should upstream this fix. This is a bug in JupyterLab. + * + * The execution count comes back from the kernel immediately + * when the execute request is made by the client, even thought + * cell might still be running. JupyterLab holds this value in + * memory with a Promise to set it later, once the execution + * state goes back to Idle. + * + * In CRDT world, we don't need to do this gymnastics, holding + * the state in a Promise. Instead, we can just watch the + * executionState and executionCount in the CRDT being maintained + * by the server-side model. + * + * This is a big win! It means user can close and re-open a + * notebook while a list of executed cells are queued. + */ +(CodeCell.prototype as any).onStateChanged = function ( + + model: ICellModel, + args: IChangedArgs +): void { + switch (args.name) { + case 'executionCount': + // NOTE: This code should not be here. It's a bandaid + // fix because executionState and executionCount + // aren't handled in a single message without CRDT/YNotebook. + // if (args.newValue !== null) { + // // Mark execution state if execution count was set. 
+ // this.model.executionState = 'idle'; + // } + this._updatePrompt(); + break; + case 'executionState': + this._updatePrompt(); + break; + case 'isDirty': + if ((model as ICodeCellModel).isDirty) { + this.addClass(DIRTY_CLASS); + } else { + this.removeClass(DIRTY_CLASS); + } + break; + default: + break; + } +} + CodeCellModel.ContentFactory.prototype.createOutputArea = function ( options: IOutputAreaModel.IOptions ): IOutputAreaModel {