diff --git a/jupyter_server_documents/kernels/kernel_client.py b/jupyter_server_documents/kernels/kernel_client.py index 658087a..9fd05c6 100644 --- a/jupyter_server_documents/kernels/kernel_client.py +++ b/jupyter_server_documents/kernels/kernel_client.py @@ -106,12 +106,11 @@ def handle_incoming_message(self, channel_name: str, msg: list[bytes]): metadata = self.session.unpack(msg[2]) cell_id = metadata.get("cellId") - # Clear output processor if this cell already has - # an existing request. + # Clear cell outputs if cell is re-executed if cell_id: existing = self.message_cache.get(cell_id=cell_id) if existing and existing['msg_id'] != msg_id: - self.output_processor.clear(cell_id) + asyncio.create_task(self.output_processor.clear_cell_outputs(cell_id)) self.message_cache.add({ "msg_id": msg_id, @@ -218,7 +217,7 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona except Exception as e: self.log.error(f"Error deserializing message: {e}") raise - + parent_msg_id = dmsg["parent_header"]["msg_id"] parent_msg_data = self.message_cache.get(parent_msg_id) cell_id = parent_msg_data.get('cell_id') @@ -273,13 +272,14 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona target_cell["execution_count"] = execution_count break - case "stream" | "display_data" | "execute_result" | "error" | "update_display_data": + case "stream" | "display_data" | "execute_result" | "error" | "update_display_data" | "clear_output": if cell_id: # Process specific output messages through an optional processor if self.output_processor and cell_id: cell_id = parent_msg_data.get('cell_id') content = self.session.unpack(dmsg["content"]) - dmsg = self.output_processor.process_output(dmsg['msg_type'], cell_id, content) + self.output_processor.process_output(dmsg['msg_type'], cell_id, content) + # Suppress forwarding of processed messages by returning None return None diff --git a/jupyter_server_documents/outputs/output_processor.py 
b/jupyter_server_documents/outputs/output_processor.py index dbd3163..a2886c9 100644 --- a/jupyter_server_documents/outputs/output_processor.py +++ b/jupyter_server_documents/outputs/output_processor.py @@ -2,13 +2,14 @@ from pycrdt import Map -from traitlets import Unicode, Bool +from traitlets import Unicode, Bool, Set from traitlets.config import LoggingConfigurable from jupyter_server_documents.kernels.message_cache import KernelMessageCache class OutputProcessor(LoggingConfigurable): _file_id = Unicode(default_value=None, allow_none=True) + _pending_clear_output_cells = Set(default_value=set()) use_outputs_service = Bool( default_value=True, @@ -46,29 +47,39 @@ def yroom_manager(self): """A shortcut for the jupyter server ydoc manager.""" return self.settings["yroom_manager"] - async def clear_cell_outputs(self, cell_id, room_id): - """Clears the outputs of a cell in ydoc""" + async def get_jupyter_ydoc(self, file_id): + room_id = f"json:notebook:{file_id}" room = self.yroom_manager.get_room(room_id) if room is None: self.log.error(f"YRoom not found: {room_id}") return - notebook = await room.get_jupyter_ydoc() - self.log.info(f"Notebook: {notebook}") - + ydoc = await room.get_jupyter_ydoc() + + return ydoc + + async def _clear_ydoc_outputs(self, cell_id): + """Clears the outputs of a cell in ydoc""" + + if not self._file_id: + return + + notebook = await self.get_jupyter_ydoc(self._file_id) cell_index, target_cell = notebook.find_cell(cell_id) if target_cell is not None: target_cell["outputs"].clear() - self.log.info(f"Cleared outputs for ydoc: {room_id} {cell_index}") + self.log.info(f"Cleared outputs for {self._file_id=}, {cell_index=}") + + async def clear_cell_outputs(self, cell_id): + """Clears all outputs for a cell on disk and in ydoc.""" - def clear(self, cell_id): - """Clear all outputs for a given cell Id.""" if self._file_id is not None: + await self._clear_ydoc_outputs(cell_id) + self._pending_clear_output_cells.discard(cell_id) + if 
self.use_outputs_service: - room_id = f"json:notebook:{self._file_id}" - asyncio.create_task(self.clear_cell_outputs(cell_id, room_id)) self.outputs_manager.clear(file_id=self._file_id, cell_id=cell_id) + - # Outgoing messages def process_output(self, msg_type: str, cell_id: str, content: dict): """Process outgoing messages from the kernel. @@ -82,11 +93,29 @@ def process_output(self, msg_type: str, cell_id: str, content: dict): The content has not been deserialized yet as we need to verify we should process it. """ - asyncio.create_task(self.output_task(msg_type, cell_id, content)) + if msg_type == "clear_output": + asyncio.create_task(self.clear_output_task(cell_id, content)) + else: + asyncio.create_task(self.output_task(msg_type, cell_id, content)) + return None # Don't allow the original message to propagate to the frontend + async def clear_output_task(self, cell_id, content): + """A coroutine to handle clear_output messages""" + + wait = content.get("wait", False) + if wait: + self._pending_clear_output_cells.add(cell_id) + else: + await self.clear_cell_outputs(cell_id) + async def output_task(self, msg_type, cell_id, content): """A coroutine to handle output messages.""" + + # Check for pending clear_output before processing output + if cell_id in self._pending_clear_output_cells: + await self.clear_cell_outputs(cell_id) + try: # TODO: The session manager may have multiple notebooks connected to the kernel # but currently get_session only returns the first. 
We need to fix this and @@ -117,15 +146,10 @@ async def output_task(self, msg_type, cell_id, content): else: output = self.transform_output(msg_type, content, ydoc=True) - # Get the notebook ydoc from the room - room_id = f"json:notebook:{file_id}" - room = self.yroom_manager.get_room(room_id) - if room is None: - self.log.error(f"YRoom not found: {room_id}") + notebook = await self.get_jupyter_ydoc(file_id) + if not notebook: return - notebook = await room.get_jupyter_ydoc() - self.log.info(f"Notebook: {notebook}") - + # Write the outputs to the ydoc cell. _, target_cell = notebook.find_cell(cell_id) if target_cell is not None and output is not None: