diff --git a/jupyter_rtc_core/__init__.py b/jupyter_rtc_core/__init__.py
index 0f7737d..9cb3923 100644
--- a/jupyter_rtc_core/__init__.py
+++ b/jupyter_rtc_core/__init__.py
@@ -7,7 +7,13 @@ import warnings
     warnings.warn("Importing 'jupyter_rtc_core' outside a proper installation.")
     __version__ = "dev"
+
+from traitlets.config import Config
+
 from .handlers import setup_handlers
+from .outputs.connection import RTCWebsocketConnection
+from .outputs.handlers import setup_handlers as setup_output_handlers
+from .outputs.manager import OutputsManager
 
 
 def _jupyter_labextension_paths():
@@ -23,6 +29,14 @@ def _jupyter_server_extension_points():
     }]
 
 
+def _link_jupyter_server_extension(server_app):
+    """Set up custom config needed by this extension."""
+    server_app.kernel_websocket_connection_class = RTCWebsocketConnection
+    c = Config()
+    c.ServerApp.kernel_websocket_connection_class = "jupyter_rtc_core.outputs.connection.RTCWebsocketConnection"
+    server_app.update_config(c)
+
+
 def _load_jupyter_server_extension(server_app):
     """Registers the API handler to receive HTTP requests from the frontend extension.
 
@@ -32,5 +46,9 @@ def _load_jupyter_server_extension(server_app):
         JupyterLab application instance
     """
     setup_handlers(server_app.web_app)
+    setup_output_handlers(server_app.web_app)
+    om = OutputsManager(config=server_app.config)
+    server_app.web_app.settings["outputs_manager"] = om
+
     name = "jupyter_rtc_core"
     server_app.log.info(f"Registered {name} server extension")
diff --git a/jupyter_rtc_core/outputs/__init__.py b/jupyter_rtc_core/outputs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/jupyter_rtc_core/outputs/connection.py b/jupyter_rtc_core/outputs/connection.py
new file mode 100644
index 0000000..bd4f713
--- /dev/null
+++ b/jupyter_rtc_core/outputs/connection.py
@@ -0,0 +1,298 @@
+import asyncio
+import json
+import typing as t
+
+from pycrdt import Map
+
+from traitlets import Dict, Unicode
+
+from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection
+from jupyter_server.services.kernels.connection.base import (
+    deserialize_binary_message,
+    deserialize_msg_from_ws_v1
+)
+
+
+class RTCWebsocketConnection(ZMQChannelsWebsocketConnection):
+
+    _cell_ids = Dict(default_value={})
+    _msg_ids = Dict(default_value={})
+    _cell_indices = Dict(default_value={})
+    file_id = Unicode(allow_none=True, default_value=None)
+
+    @property
+    def outputs_manager(self):
+        return self.websocket_handler.settings["outputs_manager"]
+
+    @property
+    def session_manager(self):
+        return self.websocket_handler.settings["session_manager"]
+
+    @property
+    def file_id_manager(self):
+        return self.websocket_handler.settings["file_id_manager"]
+
+    @property
+    def jupyter_server_ydoc(self):
+        return self.websocket_handler.settings["jupyter_server_ydoc"]
+
+    def get_part(self, field, value, msg_list):
+        """Get a part of a message."""
+        if value is None:
+            field2idx = {
+                "header": 0,
+                "parent_header": 1,
+                "metadata": 2,
+                "content": 3,
+            }
+            value = self.session.unpack(msg_list[field2idx[field]])
+        return value
+
+    def save_cell_id(self, channel, msg, msg_list):
+        """Save the cell_id <-> msg_id map.
+
+        This method is used to create a map between cell_id and msg_id.
+        Incoming execute_request messages have both a cell_id and a msg_id.
+        When output messages are sent back to the frontend, this map is used
+        to find the cell_id for a given parent msg_id.
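+
+        For example, an execute_request whose header has msg_id "abc" and
+        whose metadata has cellId "cell-1" (illustrative values) results in
+        _cell_ids["abc"] == "cell-1" and _msg_ids["cell-1"] == "abc".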
+ """ + if channel != "shell": + return + header = self.get_part("header", msg.get("header"), msg_list) + if header is None: + return + if header["msg_type"] != "execute_request": + return + msg_id = header["msg_id"] + md = self.get_part("metadata", msg.get("metadata"), msg_list) + if md is None: + return + cell_id = md.get('cellId') + existing_msg_id = self.get_msg_id(cell_id) + + if existing_msg_id != msg_id: + if self.file_id is not None: + self.log.info(f"Cell has been rerun, removing old outputs: {self.file_id} {cell_id}") + self.outputs_manager.clear(file_id=self.file_id, cell_id=cell_id) + self.log.info(f"Saving (msg_id, cell_id): ({msg_id} {cell_id})") + self._cell_ids[msg_id] = cell_id + self._msg_ids[cell_id] = msg_id + + def get_cell_id(self, msg_id): + """Retrieve a cell_id from a parent msg_id.""" + return self._cell_ids.get(msg_id) + + def get_msg_id(self, cell_id): + return self._msg_ids.get(cell_id) + + def disconnect(self): + if self.file_id is not None: + self.log.info(f"Removing server outputs: {self.file_id}") + self.outputs_manager.clear(file_id=self.file_id) + super().disconnect() + + def handle_incoming_message(self, incoming_msg: str) -> None: + """Handle incoming messages from Websocket to ZMQ Sockets.""" + ws_msg = incoming_msg + if not self.channels: + # already closed, ignore the message + self.log.debug("Received message on closed websocket %r", ws_msg) + return + + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + channel, msg_list = deserialize_msg_from_ws_v1(ws_msg) + msg = { + "header": None, + } + else: + if isinstance(ws_msg, bytes): # type:ignore[unreachable] + msg = deserialize_binary_message(ws_msg) # type:ignore[unreachable] + else: + msg = json.loads(ws_msg) + msg_list = [] + channel = msg.pop("channel", None) + + if channel is None: + self.log.warning("No channel specified, assuming shell: %s", msg) + channel = "shell" + if channel not in self.channels: + self.log.warning("No such channel: %r", channel) + return + am = self.multi_kernel_manager.allowed_message_types + ignore_msg = False + if am: + msg["header"] = self.get_part("header", msg["header"], msg_list) + assert msg["header"] is not None + if msg["header"]["msg_type"] not in am: # type:ignore[unreachable] + self.log.warning( + 'Received message of type "%s", which is not allowed. Ignoring.' 
+ % msg["header"]["msg_type"] + ) + ignore_msg = True + + # Persist the map between cell_id and msg_id + self.save_cell_id(channel, msg, msg_list) + + if not ignore_msg: + stream = self.channels[channel] + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + self.session.send_raw(stream, msg_list) + else: + self.session.send(stream, msg) + + + async def handle_outgoing_message(self, stream: str, outgoing_msg: list[t.Any]) -> None: + """Handle the outgoing messages from ZMQ sockets to Websocket.""" + msg_list = outgoing_msg + _, fed_msg_list = self.session.feed_identities(msg_list) + + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + msg = {"header": None, "parent_header": None, "content": None} + else: + msg = self.session.deserialize(fed_msg_list) + + if isinstance(stream, str): + stream = self.channels[stream] + + channel = getattr(stream, "channel", None) + parts = fed_msg_list[1:] + + self._on_error(channel, msg, parts) + + # Handle output messages + header = self.get_part("header", msg.get("header"), parts) + msg_type = header["msg_type"] + if msg_type in ("stream", "display_data", "execute_result", "error"): + self.handle_output(msg_type, msg, parts) + return + + # We can probably get rid of the rate limiting + if self._limit_rate(channel, msg, parts): + return + + if self.subprotocol == "v1.kernel.websocket.jupyter.org": + self._on_zmq_reply(stream, parts) + else: + self._on_zmq_reply(stream, msg) + + def handle_output(self, msg_type, msg, parts): + """Handle output messages by writing them to the server side Ydoc.""" + parent_header = self.get_part("parent_header", msg.get("parent_header"), parts) + msg_id = parent_header["msg_id"] + self.log.info(f"handle_output: {msg_id}") + cell_id = self.get_cell_id(msg_id) + if cell_id is None: + return + self.log.info(f"Retreiving (msg_id, cell_id): ({msg_id} {cell_id})") + content = self.get_part("content", msg.get("content"), parts) + self.log.info(f"{cell_id} {msg_type} {content}") + asyncio.create_task(self.output_task(msg_type, cell_id, content)) + + async def output_task(self, msg_type, cell_id, content): + """A coroutine to handle output messages.""" + kernel_session_manager = self.session_manager + try: + kernel_session = await kernel_session_manager.get_session(kernel_id=self.kernel_id) + except: + pass + path = kernel_session["path"] + + file_id = self.file_id_manager.get_id(path) + if file_id is None: + return + self.file_id = file_id + try: + notebook = await self.jupyter_server_ydoc.get_document(path=path, copy=False, file_format='json', content_type='notebook') + except: + return + cells = notebook.ycells + + cell_index, target_cell = self.find_cell(cell_id, cells) + if target_cell is None: + return + + # Convert from the message spec to the nbformat output structure + output = self.transform_output(msg_type, content, ydoc=False) + output_url = self.outputs_manager.write(file_id, cell_id, output) + nb_output = Map({ + "output_type": "display_data", + "data": { + 'text/html': f'Output' + }, + "metadata": { + "outputs_service": True + } + }) + target_cell["outputs"].append(nb_output) + + def find_cell(self, cell_id, cells): + # Find the target_cell and its cell_index and cache + target_cell = None + cell_index = None + try: + # See if we have a cached value for the cell_index + cell_index = self._cell_indices[cell_id] + target_cell = cells[cell_index] + except KeyError: + # Do a linear scan to find the cell + self.log.info(f"Linear scan: {cell_id}") + cell_index, target_cell = self.scan_cells(cell_id, cells) + 
+
+    def find_cell(self, cell_id, cells):
+        """Find the target cell and its index, using a cached index when available."""
+        target_cell = None
+        cell_index = None
+        try:
+            # See if we have a cached value for the cell_index
+            cell_index = self._cell_indices[cell_id]
+            target_cell = cells[cell_index]
+        except KeyError:
+            # Do a linear scan to find the cell
+            self.log.info(f"Linear scan: {cell_id}")
+            cell_index, target_cell = self.scan_cells(cell_id, cells)
+        else:
+            # Verify that the cached value still matches
+            if target_cell["id"] != cell_id:
+                self.log.info(f"Invalid cache hit: {cell_id}")
+                cell_index, target_cell = self.scan_cells(cell_id, cells)
+            else:
+                self.log.info(f"Validated cache hit: {cell_id}")
+        return cell_index, target_cell
+
+    def scan_cells(self, cell_id, cells):
+        """Find the cell with a given cell_id.
+
+        This does a simple linear scan of the cells, but in reverse order because
+        we believe that users are more often running cells towards the end of the
+        notebook.
+        """
+        target_cell = None
+        cell_index = None
+        for i in reversed(range(len(cells))):
+            cell = cells[i]
+            if cell["id"] == cell_id:
+                target_cell = cell
+                cell_index = i
+                self._cell_indices[cell_id] = cell_index
+                break
+        return cell_index, target_cell
+
+    def transform_output(self, msg_type, content, ydoc=False):
+        """Transform output from IOPub messages to the nbformat specification."""
+        if ydoc:
+            factory = Map
+        else:
+            factory = lambda x: x
+        if msg_type == "stream":
+            output = factory({
+                "output_type": "stream",
+                "text": content["text"],
+                "name": content["name"]
+            })
+        elif msg_type == "display_data":
+            output = factory({
+                "output_type": "display_data",
+                "data": content["data"],
+                "metadata": content["metadata"]
+            })
+        elif msg_type == "execute_result":
+            output = factory({
+                "output_type": "execute_result",
+                "data": content["data"],
+                "metadata": content["metadata"],
+                "execution_count": content["execution_count"]
+            })
+        elif msg_type == "error":
+            output = factory({
+                "output_type": "error",
+                "traceback": content["traceback"],
+                "ename": content["ename"],
+                "evalue": content["evalue"]
+            })
+        else:
+            raise ValueError(f"Unknown msg_type: {msg_type}")
+        return output
diff --git a/jupyter_rtc_core/outputs/handlers.py b/jupyter_rtc_core/outputs/handlers.py
new file mode 100644
index 0000000..8e37f9a
--- /dev/null
+++ b/jupyter_rtc_core/outputs/handlers.py
@@ -0,0 +1,99 @@
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+from tornado import web
+
+from jupyter_server.auth.decorator import authorized
+from jupyter_server.base.handlers import APIHandler
+from jupyter_server.utils import url_path_join
+
+
+class OutputsAPIHandler(APIHandler):
+    """An outputs service API handler."""
+
+    auth_resource = "outputs"
+
+    @property
+    def outputs(self):
+        return self.settings["outputs_manager"]
+
+    @web.authenticated
+    @authorized
+    async def get(self, file_id=None, cell_id=None, output_index=None):
+        try:
+            output = self.outputs.get(file_id, cell_id, output_index)
+        except (FileNotFoundError, KeyError):
+            self.set_status(404)
+            self.finish({"error": "Output not found."})
+        else:
+            self.set_status(200)
+            self.set_header("Content-Type", "application/json")
+            self.write(output)
+
+
+class StreamAPIHandler(APIHandler):
+    """A stream output API handler."""
+
+    auth_resource = "outputs"
+
+    @property
+    def outputs(self):
+        return self.settings["outputs_manager"]
+
+    @web.authenticated
+    @authorized
+    async def get(self, file_id=None, cell_id=None):
+        try:
+            output = self.outputs.get_stream(file_id, cell_id)
+        except (FileNotFoundError, KeyError):
+            self.set_status(404)
+            self.finish({"error": "Stream output not found."})
+        else:
+            self.set_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
+            self.set_header("Pragma", "no-cache")
+            self.set_header("Expires", "0")
+            self.set_status(200)
+            self.write(output)
+            self.finish(set_content_type="text/plain; charset=utf-8")
+
+
+# -----------------------------------------------------------------------------
+# URL to handler mappings
+# -----------------------------------------------------------------------------
+
+# Strict UUID regex (matches standard 8-4-4-4-12 UUIDs)
+_uuid_regex = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
+
+_file_id_regex = rf"(?P<file_id>{_uuid_regex})"
+_cell_id_regex = rf"(?P<cell_id>{_uuid_regex})"
+
+# non-negative integers
+_output_index_regex = r"(?P<output_index>0|[1-9]\d*)"
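+
+# For illustration (with made-up UUIDs), the routes registered in
+# setup_handlers below match requests such as:
+#   /api/outputs/<file-uuid>/<cell-uuid>/0.output  -> OutputsAPIHandler
+#   /api/outputs/<file-uuid>/<cell-uuid>/stream    -> StreamAPIHandler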
+
+def setup_handlers(web_app):
+    """Setup the handlers for the outputs service."""
+
+    handlers = [
+        (rf"/api/outputs/{_file_id_regex}/{_cell_id_regex}/{_output_index_regex}\.output", OutputsAPIHandler),
+        (rf"/api/outputs/{_file_id_regex}/{_cell_id_regex}/stream", StreamAPIHandler),
+    ]
+
+    base_url = web_app.settings["base_url"]
+    new_handlers = []
+    for handler in handlers:
+        pattern = url_path_join(base_url, handler[0])
+        new_handler = (pattern, *handler[1:])
+        new_handlers.append(new_handler)
+
+    # Add the handlers for all hosts
+    web_app.add_handlers(".*$", new_handlers)
diff --git a/jupyter_rtc_core/outputs/manager.py b/jupyter_rtc_core/outputs/manager.py
new file mode 100644
index 0000000..f4ccfdf
--- /dev/null
+++ b/jupyter_rtc_core/outputs/manager.py
@@ -0,0 +1,100 @@
+import json
+import os
+from pathlib import Path, PurePath
+import shutil
+
+from traitlets.config.configurable import LoggingConfigurable
+from traitlets import (
+    Dict,
+    Instance,
+    default,
+)
+
+from jupyter_core.paths import jupyter_runtime_dir
+
+
+class OutputsManager(LoggingConfigurable):
+    """Manages notebook outputs stored on disk and served by the outputs API."""
+
+    outputs_path = Instance(PurePath, help="The directory where outputs are stored, under the Jupyter runtime dir")
+    _last_output_index = Dict(default_value={})
+
+    @default("outputs_path")
+    def _default_outputs_path(self):
+        return Path(jupyter_runtime_dir()) / "outputs"
+
+    def _ensure_path(self, file_id, cell_id):
+        nested_dir = self.outputs_path / file_id / cell_id
+        self.log.info(f"Creating directory: {nested_dir}")
+        nested_dir.mkdir(parents=True, exist_ok=True)
+
+    def _build_path(self, file_id, cell_id=None, output_index=None):
+        path = self.outputs_path / file_id
+        if cell_id is not None:
+            path = path / cell_id
+        if output_index is not None:
+            path = path / f"{output_index}.output"
+        return path
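+
+    # The resulting on-disk layout looks like (illustrative IDs):
+    #   <outputs_path>/<file_id>/<cell_id>/0.output   one JSON file per output
+    #   <outputs_path>/<file_id>/<cell_id>/stream     concatenated stream text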
+
+    def get(self, file_id, cell_id, output_index):
+        """Get an output by file_id, cell_id, and output_index."""
+        self.log.info(f"OutputsManager.get: {file_id} {cell_id} {output_index}")
+        path = self._build_path(file_id, cell_id, output_index)
+        if not os.path.isfile(path):
+            raise FileNotFoundError(f"The output file doesn't exist: {path}")
+        with open(path, "r", encoding="utf-8") as f:
+            output = json.loads(f.read())
+        return output
+
+    def get_stream(self, file_id, cell_id):
+        """Get the stream output for a cell by file_id and cell_id."""
+        path = self._build_path(file_id, cell_id) / "stream"
+        if not os.path.isfile(path):
+            raise FileNotFoundError(f"The output file doesn't exist: {path}")
+        with open(path, "r", encoding="utf-8") as f:
+            output = f.read()
+        return output
+
+    def write(self, file_id, cell_id, output):
+        """Write a new output for file_id and cell_id."""
+        self.log.info(f"OutputsManager.write: {file_id} {cell_id} {output}")
+        if output["output_type"] == "stream":
+            url = self.write_stream(file_id, cell_id, output)
+        else:
+            url = self.write_output(file_id, cell_id, output)
+        return url
+
+    def write_output(self, file_id, cell_id, output):
+        self._ensure_path(file_id, cell_id)
+        last_index = self._last_output_index.get(cell_id, -1)
+        index = last_index + 1
+        self._last_output_index[cell_id] = index
+        path = self._build_path(file_id, cell_id, index)
+        data = json.dumps(output, ensure_ascii=False)
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(data)
+        url = f"/api/outputs/{file_id}/{cell_id}/{index}.output"
+        return url
+
+    def write_stream(self, file_id, cell_id, output):
+        self._ensure_path(file_id, cell_id)
+        path = self._build_path(file_id, cell_id) / "stream"
+        text = output["text"]
+        # Append to an existing stream file, otherwise create it
+        mode = 'a' if os.path.isfile(path) else 'w'
+        with open(path, mode, encoding="utf-8") as f:
+            f.write(text)
+        url = f"/api/outputs/{file_id}/{cell_id}/stream"
+        return url
+
+    def clear(self, file_id, cell_id=None):
+        """Remove the outputs for a file, or for a single cell if cell_id is given."""
+        path = self._build_path(file_id, cell_id)
+        if os.path.isdir(path):
+            shutil.rmtree(path)
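+
+    # Illustrative round trip (assumed IDs): write() persists an output under
+    # <outputs_path>/<file_id>/<cell_id>/0.output and returns
+    # "/api/outputs/<file_id>/<cell_id>/0.output", which OutputsAPIHandler
+    # serves back via get().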