Skip to content
18 changes: 18 additions & 0 deletions jupyter_rtc_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
import warnings
warnings.warn("Importing 'jupyter_rtc_core' outside a proper installation.")
__version__ = "dev"

from traitlets.config import Config

from .handlers import setup_handlers
from .outputs.connection import RTCWebsocketConnection
from .outputs.handlers import setup_handlers as setup_output_handlers
from .outputs.manager import OutputsManager


def _jupyter_labextension_paths():
Expand All @@ -23,6 +29,14 @@ def _jupyter_server_extension_points():
}]


def _link_jupyter_server_extension(server_app):
    """Setup custom config needed by this extension."""
    # Swap in the RTC-aware kernel websocket connection class. It is set
    # both directly on the app and through the config system so that
    # either lookup path picks it up.
    server_app.kernel_websocket_connection_class = RTCWebsocketConnection
    override = Config()
    override.ServerApp.kernel_websocket_connection_class = (
        "jupyter_rtc_core.outputs.connection.RTCWebsocketConnection"
    )
    server_app.update_config(override)


def _load_jupyter_server_extension(server_app):
    """Registers the API handler to receive HTTP requests from the frontend extension.

    Parameters
    ----------
    server_app
        JupyterLab application instance
    """
    setup_handlers(server_app.web_app)
    setup_output_handlers(server_app.web_app)
    # Share a single OutputsManager via the tornado settings dict so the
    # output handlers and the websocket connection can all reach it.
    outputs_manager = OutputsManager(config=server_app.config)
    server_app.web_app.settings["outputs_manager"] = outputs_manager

    name = "jupyter_rtc_core"
    server_app.log.info(f"Registered {name} server extension")
Empty file.
298 changes: 298 additions & 0 deletions jupyter_rtc_core/outputs/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
import asyncio
import json
import typing as t

from pycrdt import Map

from traitlets import Dict, Unicode

from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection
from jupyter_server.services.kernels.connection.base import (
deserialize_binary_message,
deserialize_msg_from_ws_v1
)

class RTCWebsocketConnection(ZMQChannelsWebsocketConnection):
    """Kernel websocket connection that redirects outputs to the server-side ydoc.

    This subclass of the default ZMQ <-> websocket bridge intercepts kernel
    output messages (``stream``, ``display_data``, ``execute_result`` and
    ``error``) on their way to the frontend. Instead of forwarding them over
    the websocket, it writes each output to the outputs manager and appends a
    small ``display_data`` placeholder (a link to the stored output) to the
    cell in the notebook's server-side ydoc.
    """

    # Map of execute_request msg_id -> cell_id, built from incoming shell messages.
    _cell_ids = Dict(default_value={})
    # Reverse map of cell_id -> most recent execute_request msg_id.
    _msg_ids = Dict(default_value={})
    # Cache of cell_id -> last known index of the cell in the notebook.
    _cell_indices = Dict(default_value={})
    # File id of the notebook backing this kernel; set lazily in output_task.
    file_id = Unicode(allow_none=True, default_value=None)

    @property
    def outputs_manager(self):
        """The OutputsManager registered in the web app settings."""
        return self.websocket_handler.settings["outputs_manager"]

    @property
    def session_manager(self):
        """The kernel session manager from the web app settings."""
        return self.websocket_handler.settings["session_manager"]

    @property
    def file_id_manager(self):
        """The file id manager from the web app settings."""
        return self.websocket_handler.settings["file_id_manager"]

    @property
    def jupyter_server_ydoc(self):
        """The jupyter-server-ydoc manager from the web app settings."""
        return self.websocket_handler.settings["jupyter_server_ydoc"]

    def get_part(self, field, value, msg_list):
        """Get a part of a message.

        If ``value`` is already deserialized, return it unchanged; otherwise
        unpack the named part (``header``, ``parent_header``, ``metadata`` or
        ``content``) from the raw ``msg_list``.
        """
        if value is None:
            field2idx = {
                "header": 0,
                "parent_header": 1,
                "metadata": 2,
                "content": 3,
            }
            value = self.session.unpack(msg_list[field2idx[field]])
        return value

    def save_cell_id(self, channel, msg, msg_list):
        """Save the cell_id <-> msg_id map.

        This method is used to create a map between cell_id and msg_id.
        Incoming execute_request messages have both a cell_id and msg_id.
        When output messages are sent back to the frontend, this map is used
        to find the cell_id for a given parent msg_id. If a cell is rerun
        (same cell_id, new msg_id), any outputs stored for the previous run
        are cleared.
        """
        if channel != "shell":
            return
        header = self.get_part("header", msg.get("header"), msg_list)
        if header is None:
            return
        if header["msg_type"] != "execute_request":
            return
        msg_id = header["msg_id"]
        md = self.get_part("metadata", msg.get("metadata"), msg_list)
        if md is None:
            return
        cell_id = md.get('cellId')
        existing_msg_id = self.get_msg_id(cell_id)

        if existing_msg_id != msg_id:
            if self.file_id is not None:
                self.log.info(f"Cell has been rerun, removing old outputs: {self.file_id} {cell_id}")
                self.outputs_manager.clear(file_id=self.file_id, cell_id=cell_id)
            self.log.info(f"Saving (msg_id, cell_id): ({msg_id} {cell_id})")
            self._cell_ids[msg_id] = cell_id
            self._msg_ids[cell_id] = msg_id

    def get_cell_id(self, msg_id):
        """Retrieve a cell_id from a parent msg_id, or None if unknown."""
        return self._cell_ids.get(msg_id)

    def get_msg_id(self, cell_id):
        """Retrieve the last execute_request msg_id for a cell_id, or None."""
        return self._msg_ids.get(cell_id)

    def disconnect(self):
        """Clear stored outputs for this notebook, then disconnect normally."""
        if self.file_id is not None:
            self.log.info(f"Removing server outputs: {self.file_id}")
            self.outputs_manager.clear(file_id=self.file_id)
        super().disconnect()

    def handle_incoming_message(self, incoming_msg: str) -> None:
        """Handle incoming messages from Websocket to ZMQ Sockets.

        This mirrors the parent implementation, with one addition: the
        cell_id <-> msg_id map is updated (via ``save_cell_id``) for every
        execute_request before the message is forwarded to the kernel.
        """
        ws_msg = incoming_msg
        if not self.channels:
            # already closed, ignore the message
            self.log.debug("Received message on closed websocket %r", ws_msg)
            return

        if self.subprotocol == "v1.kernel.websocket.jupyter.org":
            channel, msg_list = deserialize_msg_from_ws_v1(ws_msg)
            # Parts are deserialized lazily by get_part as needed.
            msg = {
                "header": None,
            }
        else:
            if isinstance(ws_msg, bytes):  # type:ignore[unreachable]
                msg = deserialize_binary_message(ws_msg)  # type:ignore[unreachable]
            else:
                msg = json.loads(ws_msg)
            msg_list = []
            channel = msg.pop("channel", None)

        if channel is None:
            self.log.warning("No channel specified, assuming shell: %s", msg)
            channel = "shell"
        if channel not in self.channels:
            self.log.warning("No such channel: %r", channel)
            return
        am = self.multi_kernel_manager.allowed_message_types
        ignore_msg = False
        if am:
            msg["header"] = self.get_part("header", msg["header"], msg_list)
            assert msg["header"] is not None
            if msg["header"]["msg_type"] not in am:  # type:ignore[unreachable]
                self.log.warning(
                    'Received message of type "%s", which is not allowed. Ignoring.'
                    % msg["header"]["msg_type"]
                )
                ignore_msg = True

        # Persist the map between cell_id and msg_id
        self.save_cell_id(channel, msg, msg_list)

        if not ignore_msg:
            stream = self.channels[channel]
            if self.subprotocol == "v1.kernel.websocket.jupyter.org":
                self.session.send_raw(stream, msg_list)
            else:
                self.session.send(stream, msg)

    async def handle_outgoing_message(self, stream: str, outgoing_msg: list[t.Any]) -> None:
        """Handle the outgoing messages from ZMQ sockets to Websocket.

        Output messages are diverted to ``handle_output`` (and hence to the
        server-side ydoc) instead of being forwarded to the frontend; all
        other messages follow the normal path.
        """
        msg_list = outgoing_msg
        _, fed_msg_list = self.session.feed_identities(msg_list)

        if self.subprotocol == "v1.kernel.websocket.jupyter.org":
            # Parts are deserialized lazily by get_part as needed.
            msg = {"header": None, "parent_header": None, "content": None}
        else:
            msg = self.session.deserialize(fed_msg_list)

        if isinstance(stream, str):
            stream = self.channels[stream]

        channel = getattr(stream, "channel", None)
        parts = fed_msg_list[1:]

        self._on_error(channel, msg, parts)

        # Handle output messages: divert them to the ydoc, do not forward.
        header = self.get_part("header", msg.get("header"), parts)
        msg_type = header["msg_type"]
        if msg_type in ("stream", "display_data", "execute_result", "error"):
            self.handle_output(msg_type, msg, parts)
            return

        # We can probably get rid of the rate limiting
        if self._limit_rate(channel, msg, parts):
            return

        if self.subprotocol == "v1.kernel.websocket.jupyter.org":
            self._on_zmq_reply(stream, parts)
        else:
            self._on_zmq_reply(stream, msg)

    def handle_output(self, msg_type, msg, parts):
        """Handle output messages by writing them to the server side Ydoc."""
        parent_header = self.get_part("parent_header", msg.get("parent_header"), parts)
        msg_id = parent_header["msg_id"]
        self.log.info(f"handle_output: {msg_id}")
        cell_id = self.get_cell_id(msg_id)
        if cell_id is None:
            # Output of an execution we did not see the request for.
            return
        self.log.info(f"Retrieving (msg_id, cell_id): ({msg_id} {cell_id})")
        content = self.get_part("content", msg.get("content"), parts)
        self.log.info(f"{cell_id} {msg_type} {content}")
        task = asyncio.create_task(self.output_task(msg_type, cell_id, content))
        # Keep a strong reference to the task so it is not garbage collected
        # before it completes (see asyncio.create_task documentation).
        tasks = getattr(self, "_output_tasks", None)
        if tasks is None:
            tasks = self._output_tasks = set()
        tasks.add(task)
        task.add_done_callback(tasks.discard)

    async def output_task(self, msg_type, cell_id, content):
        """A coroutine to handle output messages.

        Looks up the notebook for this kernel, writes the output through the
        outputs manager, and appends a placeholder output (a link to the
        stored output) to the target cell in the server-side ydoc.
        """
        kernel_session_manager = self.session_manager
        try:
            kernel_session = await kernel_session_manager.get_session(kernel_id=self.kernel_id)
        except Exception:
            # No session for this kernel (e.g. a console kernel); nothing to do.
            self.log.debug("No kernel session for kernel_id=%s", self.kernel_id, exc_info=True)
            return
        path = kernel_session["path"]

        file_id = self.file_id_manager.get_id(path)
        if file_id is None:
            return
        self.file_id = file_id
        try:
            notebook = await self.jupyter_server_ydoc.get_document(path=path, copy=False, file_format='json', content_type='notebook')
        except Exception:
            self.log.debug("Could not get ydoc for path=%s", path, exc_info=True)
            return
        cells = notebook.ycells

        cell_index, target_cell = self.find_cell(cell_id, cells)
        if target_cell is None:
            return

        # Convert from the message spec to the nbformat output structure
        output = self.transform_output(msg_type, content, ydoc=False)
        output_url = self.outputs_manager.write(file_id, cell_id, output)
        nb_output = Map({
            "output_type": "display_data",
            "data": {
                'text/html': f'<a href="{output_url}">Output</a>'
            },
            "metadata": {
                "outputs_service": True
            }
        })
        target_cell["outputs"].append(nb_output)

    def find_cell(self, cell_id, cells):
        """Find the (index, cell) for cell_id, using a validated cache.

        A cached index is verified against the current cell list and a
        linear scan is performed on a miss or stale hit.
        """
        target_cell = None
        cell_index = None
        try:
            # See if we have a cached value for the cell_index
            cell_index = self._cell_indices[cell_id]
            target_cell = cells[cell_index]
        except KeyError:
            # Do a linear scan to find the cell
            self.log.info(f"Linear scan: {cell_id}")
            cell_index, target_cell = self.scan_cells(cell_id, cells)
        else:
            # Verify that the cached value still matches
            if target_cell["id"] != cell_id:
                self.log.info(f"Invalid cache hit: {cell_id}")
                cell_index, target_cell = self.scan_cells(cell_id, cells)
            else:
                self.log.info(f"Validated cache hit: {cell_id}")
        return cell_index, target_cell

    def scan_cells(self, cell_id, cells):
        """Find the cell with a given cell_id.

        This does a simple linear scan of the cells, but in reverse order because
        we believe that users are more often running cells towards the end of the
        notebook. Returns ``(None, None)`` when the cell is not found.
        """
        target_cell = None
        cell_index = None
        for i in reversed(range(0, len(cells))):
            cell = cells[i]
            if cell["id"] == cell_id:
                target_cell = cell
                cell_index = i
                self._cell_indices[cell_id] = cell_index
                break
        return cell_index, target_cell

    def transform_output(self, msg_type, content, ydoc=False):
        """Transform output from IOPub messages to the nbformat specification.

        When ``ydoc`` is True the output is wrapped in a pycrdt ``Map`` so it
        can be inserted into a shared document; otherwise a plain dict is
        returned.

        Raises
        ------
        ValueError
            If ``msg_type`` is not one of the four output message types.
        """
        factory = Map if ydoc else (lambda x: x)
        if msg_type == "stream":
            output = factory({
                "output_type": "stream",
                "text": content["text"],
                "name": content["name"]
            })
        elif msg_type == "display_data":
            output = factory({
                "output_type": "display_data",
                "data": content["data"],
                "metadata": content["metadata"]
            })
        elif msg_type == "execute_result":
            output = factory({
                "output_type": "execute_result",
                "data": content["data"],
                "metadata": content["metadata"],
                "execution_count": content["execution_count"]
            })
        elif msg_type == "error":
            output = factory({
                "output_type": "error",
                "traceback": content["traceback"],
                "ename": content["ename"],
                "evalue": content["evalue"]
            })
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(f"Unknown output message type: {msg_type}")
        return output
Loading