diff --git a/jupyter_rtc_core/__init__.py b/jupyter_rtc_core/__init__.py
index 0f7737d..9cb3923 100644
--- a/jupyter_rtc_core/__init__.py
+++ b/jupyter_rtc_core/__init__.py
@@ -7,7 +7,13 @@
import warnings
warnings.warn("Importing 'jupyter_rtc_core' outside a proper installation.")
__version__ = "dev"
+
+from traitlets.config import Config
+
from .handlers import setup_handlers
+from .outputs.connection import RTCWebsocketConnection
+from .outputs.handlers import setup_handlers as setup_output_handlers
+from .outputs.manager import OutputsManager
def _jupyter_labextension_paths():
@@ -23,6 +29,14 @@ def _jupyter_server_extension_points():
}]
+def _link_jupyter_server_extension(server_app):
+ """Setup custom config needed by this extension."""
+ server_app.kernel_websocket_connection_class = RTCWebsocketConnection
+ c = Config()
+ c.ServerApp.kernel_websocket_connection_class = "jupyter_rtc_core.outputs.connection.RTCWebsocketConnection"
+ server_app.update_config(c)
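+    # The Config entry mirrors what a user could set by hand in a
+    # jupyter_server_config.py (illustrative snippet):
+    #
+    #   c.ServerApp.kernel_websocket_connection_class = (
+    #       "jupyter_rtc_core.outputs.connection.RTCWebsocketConnection"
+    #   )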
+
+
def _load_jupyter_server_extension(server_app):
"""Registers the API handler to receive HTTP requests from the frontend extension.
@@ -32,5 +46,9 @@ def _load_jupyter_server_extension(server_app):
JupyterLab application instance
"""
setup_handlers(server_app.web_app)
+ setup_output_handlers(server_app.web_app)
+ om = OutputsManager(config=server_app.config)
+ server_app.web_app.settings["outputs_manager"] = om
+
name = "jupyter_rtc_core"
server_app.log.info(f"Registered {name} server extension")
diff --git a/jupyter_rtc_core/outputs/__init__.py b/jupyter_rtc_core/outputs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/jupyter_rtc_core/outputs/connection.py b/jupyter_rtc_core/outputs/connection.py
new file mode 100644
index 0000000..bd4f713
--- /dev/null
+++ b/jupyter_rtc_core/outputs/connection.py
@@ -0,0 +1,298 @@
+import asyncio
+import json
+import typing as t
+
+from pycrdt import Map
+
+from traitlets import Dict, Unicode
+
+from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection
+from jupyter_server.services.kernels.connection.base import (
+ deserialize_binary_message,
+ deserialize_msg_from_ws_v1
+)
+
+class RTCWebsocketConnection(ZMQChannelsWebsocketConnection):
+
+ _cell_ids = Dict(default_value={})
+ _msg_ids = Dict(default_value={})
+ _cell_indices = Dict(default_value={})
+ file_id = Unicode(allow_none=True, default_value=None)
+
+ @property
+ def outputs_manager(self):
+ return self.websocket_handler.settings["outputs_manager"]
+
+ @property
+ def session_manager(self):
+ return self.websocket_handler.settings["session_manager"]
+
+ @property
+ def file_id_manager(self):
+ return self.websocket_handler.settings["file_id_manager"]
+
+ @property
+ def jupyter_server_ydoc(self):
+ return self.websocket_handler.settings["jupyter_server_ydoc"]
+
+ def get_part(self, field, value, msg_list):
+ """Get a part of a message."""
+ if value is None:
+ field2idx = {
+ "header": 0,
+ "parent_header": 1,
+ "metadata": 2,
+ "content": 3,
+ }
+ value = self.session.unpack(msg_list[field2idx[field]])
+ return value
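+    # Illustrative behavior: under the v1.kernel.websocket.jupyter.org
+    # subprotocol the parts are still packed in msg_list, so
+    # get_part("header", None, msg_list) unpacks the header; under the legacy
+    # protocol msg["header"] is already a dict and is returned unchanged.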
+
+ def save_cell_id(self, channel, msg, msg_list):
+ """Save the cell_id <-> msg_id map.
+
+        This method maintains a bidirectional map between cell_id and msg_id.
+        Incoming execute_request messages carry both a cell_id and a msg_id.
+        When output messages are sent back to the frontend, this map is used
+        to find the cell_id for a given parent msg_id.
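+
+        For example (hypothetical IDs), an execute_request whose header has
+        msg_id "abc" and whose metadata has cellId "cell-1" records
+        _cell_ids["abc"] = "cell-1" and _msg_ids["cell-1"] = "abc".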
+ """
+ if channel != "shell":
+ return
+ header = self.get_part("header", msg.get("header"), msg_list)
+ if header is None:
+ return
+ if header["msg_type"] != "execute_request":
+ return
+ msg_id = header["msg_id"]
+ md = self.get_part("metadata", msg.get("metadata"), msg_list)
+ if md is None:
+ return
+ cell_id = md.get('cellId')
+ existing_msg_id = self.get_msg_id(cell_id)
+
+ if existing_msg_id != msg_id:
+ if self.file_id is not None:
+ self.log.info(f"Cell has been rerun, removing old outputs: {self.file_id} {cell_id}")
+ self.outputs_manager.clear(file_id=self.file_id, cell_id=cell_id)
+ self.log.info(f"Saving (msg_id, cell_id): ({msg_id} {cell_id})")
+ self._cell_ids[msg_id] = cell_id
+ self._msg_ids[cell_id] = msg_id
+
+ def get_cell_id(self, msg_id):
+ """Retrieve a cell_id from a parent msg_id."""
+ return self._cell_ids.get(msg_id)
+
+ def get_msg_id(self, cell_id):
+ return self._msg_ids.get(cell_id)
+
+ def disconnect(self):
+ if self.file_id is not None:
+ self.log.info(f"Removing server outputs: {self.file_id}")
+ self.outputs_manager.clear(file_id=self.file_id)
+ super().disconnect()
+
+ def handle_incoming_message(self, incoming_msg: str) -> None:
+ """Handle incoming messages from Websocket to ZMQ Sockets."""
+ ws_msg = incoming_msg
+ if not self.channels:
+ # already closed, ignore the message
+ self.log.debug("Received message on closed websocket %r", ws_msg)
+ return
+
+ if self.subprotocol == "v1.kernel.websocket.jupyter.org":
+ channel, msg_list = deserialize_msg_from_ws_v1(ws_msg)
+ msg = {
+ "header": None,
+ }
+ else:
+ if isinstance(ws_msg, bytes): # type:ignore[unreachable]
+ msg = deserialize_binary_message(ws_msg) # type:ignore[unreachable]
+ else:
+ msg = json.loads(ws_msg)
+ msg_list = []
+ channel = msg.pop("channel", None)
+
+ if channel is None:
+ self.log.warning("No channel specified, assuming shell: %s", msg)
+ channel = "shell"
+ if channel not in self.channels:
+ self.log.warning("No such channel: %r", channel)
+ return
+ am = self.multi_kernel_manager.allowed_message_types
+ ignore_msg = False
+ if am:
+ msg["header"] = self.get_part("header", msg["header"], msg_list)
+ assert msg["header"] is not None
+ if msg["header"]["msg_type"] not in am: # type:ignore[unreachable]
+ self.log.warning(
+ 'Received message of type "%s", which is not allowed. Ignoring.'
+ % msg["header"]["msg_type"]
+ )
+ ignore_msg = True
+
+ # Persist the map between cell_id and msg_id
+ self.save_cell_id(channel, msg, msg_list)
+
+ if not ignore_msg:
+ stream = self.channels[channel]
+ if self.subprotocol == "v1.kernel.websocket.jupyter.org":
+ self.session.send_raw(stream, msg_list)
+ else:
+ self.session.send(stream, msg)
+
+ async def handle_outgoing_message(self, stream: str, outgoing_msg: list[t.Any]) -> None:
+ """Handle the outgoing messages from ZMQ sockets to Websocket."""
+ msg_list = outgoing_msg
+ _, fed_msg_list = self.session.feed_identities(msg_list)
+
+ if self.subprotocol == "v1.kernel.websocket.jupyter.org":
+ msg = {"header": None, "parent_header": None, "content": None}
+ else:
+ msg = self.session.deserialize(fed_msg_list)
+
+ if isinstance(stream, str):
+ stream = self.channels[stream]
+
+ channel = getattr(stream, "channel", None)
+ parts = fed_msg_list[1:]
+
+ self._on_error(channel, msg, parts)
+
+ # Handle output messages
+ header = self.get_part("header", msg.get("header"), parts)
+ msg_type = header["msg_type"]
+ if msg_type in ("stream", "display_data", "execute_result", "error"):
+ self.handle_output(msg_type, msg, parts)
+ return
+
+ # We can probably get rid of the rate limiting
+ if self._limit_rate(channel, msg, parts):
+ return
+
+ if self.subprotocol == "v1.kernel.websocket.jupyter.org":
+ self._on_zmq_reply(stream, parts)
+ else:
+ self._on_zmq_reply(stream, msg)
+
+ def handle_output(self, msg_type, msg, parts):
+ """Handle output messages by writing them to the server side Ydoc."""
+ parent_header = self.get_part("parent_header", msg.get("parent_header"), parts)
+ msg_id = parent_header["msg_id"]
+ self.log.info(f"handle_output: {msg_id}")
+ cell_id = self.get_cell_id(msg_id)
+ if cell_id is None:
+ return
+ self.log.info(f"Retreiving (msg_id, cell_id): ({msg_id} {cell_id})")
+ content = self.get_part("content", msg.get("content"), parts)
+ self.log.info(f"{cell_id} {msg_type} {content}")
+ asyncio.create_task(self.output_task(msg_type, cell_id, content))
+
+ async def output_task(self, msg_type, cell_id, content):
+ """A coroutine to handle output messages."""
+ kernel_session_manager = self.session_manager
+ try:
+ kernel_session = await kernel_session_manager.get_session(kernel_id=self.kernel_id)
+        except Exception:
+            # Without a kernel session we cannot resolve the notebook path.
+            return
+ path = kernel_session["path"]
+
+ file_id = self.file_id_manager.get_id(path)
+ if file_id is None:
+ return
+ self.file_id = file_id
+ try:
+ notebook = await self.jupyter_server_ydoc.get_document(path=path, copy=False, file_format='json', content_type='notebook')
+        except Exception:
+ return
+ cells = notebook.ycells
+
+ cell_index, target_cell = self.find_cell(cell_id, cells)
+ if target_cell is None:
+ return
+
+ # Convert from the message spec to the nbformat output structure
+ output = self.transform_output(msg_type, content, ydoc=False)
+ output_url = self.outputs_manager.write(file_id, cell_id, output)
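+        # The full output is persisted to disk by the OutputsManager; the ydoc
+        # only receives a small display_data placeholder linking to it, which
+        # keeps the shared document lightweight.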
+ nb_output = Map({
+ "output_type": "display_data",
+ "data": {
+                'text/html': f'<a href="{output_url}">Output</a>'
+ },
+ "metadata": {
+ "outputs_service": True
+ }
+ })
+ target_cell["outputs"].append(nb_output)
+
+ def find_cell(self, cell_id, cells):
+ # Find the target_cell and its cell_index and cache
+ target_cell = None
+ cell_index = None
+ try:
+ # See if we have a cached value for the cell_index
+ cell_index = self._cell_indices[cell_id]
+ target_cell = cells[cell_index]
+ except KeyError:
+ # Do a linear scan to find the cell
+ self.log.info(f"Linear scan: {cell_id}")
+ cell_index, target_cell = self.scan_cells(cell_id, cells)
+ else:
+ # Verify that the cached value still matches
+ if target_cell["id"] != cell_id:
+ self.log.info(f"Invalid cache hit: {cell_id}")
+ cell_index, target_cell = self.scan_cells(cell_id, cells)
+ else:
+ self.log.info(f"Validated cache hit: {cell_id}")
+ return cell_index, target_cell
+
+ def scan_cells(self, cell_id, cells):
+ """Find the cell with a given cell_id.
+
+ This does a simple linear scan of the cells, but in reverse order because
+ we believe that users are more often running cells towards the end of the
+ notebook.
+ """
+ target_cell = None
+ cell_index = None
+ for i in reversed(range(0, len(cells))):
+ cell = cells[i]
+ if cell["id"] == cell_id:
+ target_cell = cell
+ cell_index = i
+ self._cell_indices[cell_id] = cell_index
+ break
+ return cell_index, target_cell
+
+ def transform_output(self, msg_type, content, ydoc=False):
+ """Transform output from IOPub messages to the nbformat specification."""
+ if ydoc:
+ factory = Map
+ else:
+ factory = lambda x: x
+ if msg_type == "stream":
+ output = factory({
+ "output_type": "stream",
+ "text": content["text"],
+ "name": content["name"]
+ })
+ elif msg_type == "display_data":
+ output = factory({
+ "output_type": "display_data",
+ "data": content["data"],
+ "metadata": content["metadata"]
+ })
+ elif msg_type == "execute_result":
+ output = factory({
+ "output_type": "execute_result",
+ "data": content["data"],
+ "metadata": content["metadata"],
+ "execution_count": content["execution_count"]
+ })
+ elif msg_type == "error":
+ output = factory({
+ "output_type": "error",
+ "traceback": content["traceback"],
+ "ename": content["ename"],
+ "evalue": content["evalue"]
+ })
+ return output
diff --git a/jupyter_rtc_core/outputs/handlers.py b/jupyter_rtc_core/outputs/handlers.py
new file mode 100644
index 0000000..8e37f9a
--- /dev/null
+++ b/jupyter_rtc_core/outputs/handlers.py
@@ -0,0 +1,99 @@
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+from tornado import web
+
+from jupyter_server.auth.decorator import authorized
+from jupyter_server.base.handlers import APIHandler
+from jupyter_server.utils import url_path_join
+
+
+class OutputsAPIHandler(APIHandler):
+ """An outputs service API handler."""
+
+ auth_resource = "outputs"
+
+ @property
+ def outputs(self):
+ return self.settings["outputs_manager"]
+
+ @web.authenticated
+ @authorized
+ async def get(self, file_id=None, cell_id=None, output_index=None):
+ try:
+ output = self.outputs.get(file_id, cell_id, output_index)
+ except (FileNotFoundError, KeyError):
+ self.set_status(404)
+ self.finish({"error": "Output not found."})
+ else:
+ self.set_status(200)
+ self.set_header("Content-Type", "application/json")
+ self.write(output)
+
+
+class StreamAPIHandler(APIHandler):
+ """An outputs service API handler."""
+
+ auth_resource = "outputs"
+
+ @property
+ def outputs(self):
+ return self.settings["outputs_manager"]
+
+ @web.authenticated
+ @authorized
+ async def get(self, file_id=None, cell_id=None):
+ try:
+ output = self.outputs.get_stream(file_id, cell_id)
+ except (FileNotFoundError, KeyError):
+ self.set_status(404)
+ self.finish({"error": "Stream output not found."})
+ else:
+ self.set_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
+ self.set_header("Pragma", "no-cache")
+ self.set_header("Expires", "0")
+ self.set_status(200)
+ self.write(output)
+ self.finish(set_content_type="text/plain; charset=utf-8")
+
+
+# -----------------------------------------------------------------------------
+# URL to handler mappings
+# -----------------------------------------------------------------------------
+
+# Strict UUID regex (matches standard 8-4-4-4-12 UUIDs)
+_uuid_regex = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
+
+_file_id_regex = rf"(?P<file_id>{_uuid_regex})"
+_cell_id_regex = rf"(?P<cell_id>{_uuid_regex})"
+
+# non-negative integers
+_output_index_regex = r"(?P<output_index>0|[1-9]\d*)"
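+
+# For example (hypothetical IDs), the routes below match URLs such as:
+#   /api/outputs/123e4567-e89b-12d3-a456-426614174000/9f8e7d6c-5b4a-3210-fedc-ba9876543210/0.output
+#   /api/outputs/123e4567-e89b-12d3-a456-426614174000/9f8e7d6c-5b4a-3210-fedc-ba9876543210/stream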
+
+def setup_handlers(web_app):
+ """Setup the handlers for the outputs service."""
+
+ handlers = [
+ (rf"/api/outputs/{_file_id_regex}/{_cell_id_regex}/{_output_index_regex}.output", OutputsAPIHandler),
+ (rf"/api/outputs/{_file_id_regex}/{_cell_id_regex}/stream", StreamAPIHandler),
+ ]
+
+ base_url = web_app.settings["base_url"]
+ new_handlers = []
+ for handler in handlers:
+ pattern = url_path_join(base_url, handler[0])
+ new_handler = (pattern, *handler[1:])
+ new_handlers.append(new_handler)
+
+ # Add the handler for all hosts
+ web_app.add_handlers(".*$", new_handlers)
diff --git a/jupyter_rtc_core/outputs/manager.py b/jupyter_rtc_core/outputs/manager.py
new file mode 100644
index 0000000..f4ccfdf
--- /dev/null
+++ b/jupyter_rtc_core/outputs/manager.py
@@ -0,0 +1,100 @@
+import json
+import os
+from pathlib import Path, PurePath
+import shutil
+
+from traitlets.config.configurable import LoggingConfigurable
+from traitlets import Dict, Instance, default
+
+from jupyter_core.paths import jupyter_runtime_dir
+
+class OutputsManager(LoggingConfigurable):
+
+    outputs_path = Instance(PurePath, help="The root directory under which outputs are stored")
+ _last_output_index = Dict(default_value={})
+
+ @default("outputs_path")
+ def _default_outputs_path(self):
+ return Path(jupyter_runtime_dir()) / "outputs"
+
+ def _ensure_path(self, file_id, cell_id):
+ nested_dir = self.outputs_path / file_id / cell_id
+ self.log.info(f"Creating directory: {nested_dir}")
+ nested_dir.mkdir(parents=True, exist_ok=True)
+
+ def _build_path(self, file_id, cell_id=None, output_index=None):
+ path = self.outputs_path / file_id
+ if cell_id is not None:
+ path = path / cell_id
+ if output_index is not None:
+ path = path / f"{output_index}.output"
+ return path
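+    # Illustrative layout: _build_path(file_id, cell_id, 2) resolves to
+    # <outputs_path>/<file_id>/<cell_id>/2.output on disk.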
+
+ def get(self, file_id, cell_id, output_index):
+ """Get an outputs by file_id, cell_id, and output_index."""
+ self.log.info(f"OutputsManager.get: {file_id} {cell_id} {output_index}")
+ path = self._build_path(file_id, cell_id, output_index)
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"The output file doesn't exist: {path}")
+ with open(path, "r", encoding="utf-8") as f:
+ output = json.loads(f.read())
+ return output
+
+ def get_stream(self, file_id, cell_id):
+ "Get the stream output for a cell by file_id and cell_id."
+ path = self._build_path(file_id, cell_id) / "stream"
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"The output file doesn't exist: {path}")
+ with open(path, "r", encoding="utf-8") as f:
+ output = f.read()
+ return output
+
+ def write(self, file_id, cell_id, output):
+ """Write a new output for file_id and cell_id."""
+ self.log.info(f"OutputsManager.write: {file_id} {cell_id} {output}")
+ if output["output_type"] == "stream":
+ url = self.write_stream(file_id, cell_id, output)
+ else:
+ url = self.write_output(file_id, cell_id, output)
+ return url
+
+ def write_output(self, file_id, cell_id, output):
+ self._ensure_path(file_id, cell_id)
+ last_index = self._last_output_index.get(cell_id, -1)
+ index = last_index + 1
+ self._last_output_index[cell_id] = index
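+        # Indices are allocated per cell: the first output becomes 0.output,
+        # the next 1.output, and so on.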
+ path = self._build_path(file_id, cell_id, index)
+ data = json.dumps(output, ensure_ascii=False)
+ with open(path, "w", encoding="utf-8") as f:
+ f.write(data)
+ url = f"/api/outputs/{file_id}/{cell_id}/{index}.output"
+ return url
+
+ def write_stream(self, file_id, cell_id, output):
+ self._ensure_path(file_id, cell_id)
+ path = self._build_path(file_id, cell_id) / "stream"
+ text = output["text"]
+        mode = 'a' if os.path.isfile(path) else 'w'
+        with open(path, mode, encoding="utf-8") as f:
+ f.write(text)
+ url = f"/api/outputs/{file_id}/{cell_id}/stream"
+ return url
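+    # Stream text accumulates in a single "stream" file per cell, so repeated
+    # stdout/stderr writes append to it rather than creating new output files.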
+
+    def clear(self, file_id, cell_id=None):
+        """Remove all outputs for a file, or for a single cell if cell_id is given."""
+        path = self._build_path(file_id, cell_id)
+        shutil.rmtree(path, ignore_errors=True)