Skip to content

Commit 0c7b452

Browse files
authored

File tree

12 files changed

+477
-3
lines changed

12 files changed

+477
-3
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ This extension is composed of a Python package named `jupyter_rtc_core`
88
for the server extension and a NPM package named `@jupyter/rtc-core`
99
for the frontend extension.
1010

11+
## Try it out
12+
13+
Run JupyterLab with the proper configuration:
14+
```
15+
jupyter lab --config jupyter_config.py
16+
```
17+
18+
1119
## Requirements
1220

1321
- JupyterLab >= 4.0.0

jupyter_rtc_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
warnings.warn("Importing 'jupyter_rtc_core' outside a proper installation.")
99
__version__ = "dev"
1010

11+
1112
from .app import RtcExtensionApp
1213

1314

jupyter_rtc_core/app.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from jupyter_server.extension.application import ExtensionApp
2+
from traitlets.config import Config
23

34
from .handlers import RouteHandler
45
from .websockets import GlobalAwarenessWebsocket, YRoomWebsocket
@@ -13,11 +14,20 @@ class RtcExtensionApp(ExtensionApp):
1314
# this can be deleted prior to initial release.
1415
(r"jupyter-rtc-core/get-example/?", RouteHandler),
1516
# global awareness websocket
16-
(r"api/collaboration/room/JupyterLab:globalAwareness/?", GlobalAwarenessWebsocket),
17-
# ydoc websocket
18-
(r"api/collaboration/room/(.*)", YRoomWebsocket)
17+
# (r"api/collaboration/room/JupyterLab:globalAwareness/?", GlobalAwarenessWebsocket),
18+
# # ydoc websocket
19+
# (r"api/collaboration/room/(.*)", YRoomWebsocket)
1920
]
2021

2122
def initialize(self):
2223
super().initialize()
2324

25+
26+
def _link_jupyter_server_extension(self, server_app):
27+
"""Setup custom config needed by this extension."""
28+
c = Config()
29+
c.ServerApp.kernel_websocket_connection_class = "jupyter_rtc_core.kernels.websocket_connection.NextGenKernelWebsocketConnection"
30+
c.ServerApp.kernel_manager_class = "jupyter_rtc_core.kernels.multi_kernel_manager.NextGenMappingKernelManager"
31+
c.MultiKernelManager.kernel_manager_class = "jupyter_rtc_core.kernels.kernel_manager.NextGenKernelManager"
32+
server_app.update_config(c)
33+
super()._link_jupyter_server_extension(server_app)

jupyter_rtc_core/handlers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ def get(self):
1212
self.finish(json.dumps({
1313
"data": "This is /jupyter-rtc-core/get-example endpoint!"
1414
}))
15+
16+
17+

jupyter_rtc_core/kernels/__init__.py

Whitespace-only changes.
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import asyncio
2+
import json
3+
import typing as t
4+
from traitlets import Set
5+
from traitlets import Instance
6+
from traitlets import Any
7+
from .utils import LRUCache
8+
from jupyter_client.asynchronous.client import AsyncKernelClient
9+
import anyio
10+
11+
from jupyter_client.session import Session
12+
13+
class NextGenAsyncKernelClient(AsyncKernelClient):
14+
"""
15+
A ZMQ-based kernel client class that managers all listeners to a kernel.
16+
"""
17+
# Having this message cache is not ideal.
18+
# Unfortunately, we don't include the parent channel
19+
# in the messages that generate IOPub status messages, thus,
20+
# we can't differential between the control channel vs.
21+
# shell channel status. This message cache gives us
22+
# the ability to map status message back to their source.
23+
message_source_cache = Instance(
24+
default_value=LRUCache(maxsize=1000), klass=LRUCache
25+
)
26+
27+
# A set of callables that are called when a
28+
# ZMQ message comes back from the kernel.
29+
_listeners = Set(allow_none=True)
30+
31+
async def start_listening(self):
32+
"""Start listening to messages coming from the kernel.
33+
34+
Use anyio to setup a task group for listening.
35+
"""
36+
# Wrap a taskgroup so that it can be backgrounded.
37+
async def _listening():
38+
async with anyio.create_task_group() as tg:
39+
for channel_name in ["shell", "control", "stdin", "iopub"]:
40+
tg.start_soon(
41+
self._listen_for_messages, channel_name
42+
)
43+
44+
# Background this task.
45+
self._listening_task = asyncio.create_task(_listening())
46+
47+
async def stop_listening(self):
48+
# If the listening task isn't defined yet
49+
# do nothing.
50+
if not self._listening_task:
51+
return
52+
53+
# Attempt to cancel the task.
54+
self._listening_task.cancel()
55+
try:
56+
# Await cancellation.
57+
await self._listening_task
58+
except asyncio.CancelledError:
59+
self.log.info("Disconnected from client from the kernel.")
60+
# Log any exceptions that were raised.
61+
except Exception as err:
62+
self.log.error(err)
63+
64+
_listening_task: t.Optional[t.Awaitable] = Any(allow_none=True)
65+
66+
def send_message(self, channel_name, msg):
67+
"""Use the given session to send the message."""
68+
# Cache the message ID and its socket name so that
69+
# any response message can be mapped back to the
70+
# source channel.
71+
header = header = json.loads(msg[0])
72+
msg_id = header["msg_id"]
73+
self.message_source_cache[msg_id] = channel_name
74+
channel = getattr(self, f"{channel_name}_channel")
75+
channel.session.send_raw(channel.socket, msg)
76+
77+
async def recv_message(self, channel_name, msg):
78+
"""This is the main method that consumes every
79+
message coming back from the kernel. It parses the header
80+
(not the content, which might be large) and updates
81+
the last_activity, execution_state, and lifecycle_state
82+
when appropriate. Then, it routes the message
83+
to all listeners.
84+
"""
85+
# Broadcast messages
86+
async with anyio.create_task_group() as tg:
87+
# Broadcast the message to all listeners.
88+
for listener in self._listeners:
89+
async def _wrap_listener(listener, channel_name, msg):
90+
"""
91+
Wrap the listener to ensure its async and
92+
logs (instead of raises) exceptions.
93+
"""
94+
try:
95+
listener(channel_name, msg)
96+
except Exception as err:
97+
self.log.error(err)
98+
99+
tg.start_soon(_wrap_listener, listener, channel_name, msg)
100+
101+
def add_listener(self, callback: t.Callable[[dict], None]):
102+
"""Add a listener to the ZMQ Interface.
103+
104+
A listener is a callable function/method that takes
105+
the deserialized (minus the content) ZMQ message.
106+
107+
If the listener is already registered, it won't be registered again.
108+
"""
109+
self._listeners.add(callback)
110+
111+
def remove_listener(self, callback: t.Callable[[dict], None]):
112+
"""Remove a listener to teh ZMQ interface. If the listener
113+
is not found, this method does nothing.
114+
"""
115+
self._listeners.discard(callback)
116+
117+
async def _listen_for_messages(self, channel_name):
118+
"""The basic polling loop for listened to kernel messages
119+
on a ZMQ socket.
120+
"""
121+
# Wire up the ZMQ sockets
122+
# Setup up ZMQSocket broadcasting.
123+
channel = getattr(self, f"{channel_name}_channel")
124+
while True:
125+
# Wait for a message
126+
await channel.socket.poll(timeout=float("inf"))
127+
raw_msg = await channel.socket.recv_multipart()
128+
try:
129+
await self.recv_message(channel_name, raw_msg)
130+
except Exception as err:
131+
self.log.error(err)
132+
133+
def kernel_info(self):
134+
msg = self.session.msg("kernel_info_request")
135+
# Send message, skipping the delimiter and signature
136+
msg = self.session.serialize(msg)[2:]
137+
self.send_message("shell", msg)
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import typing
2+
import asyncio
3+
from traitlets import default
4+
from traitlets import Instance
5+
from traitlets import Int
6+
from traitlets import Dict
7+
from traitlets import Type
8+
from traitlets import Unicode
9+
from traitlets import validate
10+
from traitlets import observe
11+
from traitlets import Set
12+
from traitlets import TraitError
13+
from traitlets import DottedObjectName
14+
from traitlets.utils.importstring import import_item
15+
16+
from jupyter_client.manager import AsyncKernelManager
17+
18+
from . import types
19+
from . import states
20+
from .kernel_client import AsyncKernelClient
21+
22+
23+
class NextGenKernelManager(AsyncKernelManager):
24+
25+
main_client = Instance(AsyncKernelClient, allow_none=True)
26+
27+
client_class = DottedObjectName(
28+
"jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient"
29+
)
30+
31+
client_factory: Type = Type(klass="jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient")
32+
33+
# Configurable settings in a kernel manager that I want.
34+
time_to_connect: int = Int(
35+
default_value=10,
36+
help="The timeout for connecting to a kernel."
37+
).tag(config=True)
38+
39+
execution_state: types.EXECUTION_STATES = Unicode()
40+
41+
@validate("execution_state")
42+
def _validate_execution_state(self, proposal: dict):
43+
if not proposal["value"] in states.EXECUTION_STATES:
44+
raise TraitError(f"execution_state must be one of {states.EXECUTION_STATES}")
45+
return proposal["value"]
46+
47+
lifecycle_state: types.EXECUTION_STATES = Unicode()
48+
49+
@validate("lifecycle_state")
50+
def _validate_lifecycle_state(self, proposal: dict):
51+
if not proposal["value"] in states.LIFECYCLE_STATES:
52+
raise TraitError(f"lifecycle_state must be one of {states.LIFECYCLE_STATES}")
53+
return proposal["value"]
54+
55+
state = Dict()
56+
57+
@default('state')
58+
def _default_state(self):
59+
return {
60+
"execution_state": self.execution_state,
61+
"lifecycle_state": self.lifecycle_state
62+
}
63+
64+
@observe('execution_state')
65+
def _observer_execution_state(self, change):
66+
state = self.state
67+
state["execution_state"] = change['new']
68+
self.state = state
69+
70+
@observe('lifecycle_state')
71+
def _observer_lifecycle_state(self, change):
72+
state = self.state
73+
state["lifecycle_state"] = change['new']
74+
self.state = state
75+
76+
@validate('state')
77+
def _validate_state(self, change):
78+
value = change['value']
79+
if 'execution_state' not in value or 'lifecycle_state' not in value:
80+
TraitError("State needs to include execution_state and lifecycle_state")
81+
return value
82+
83+
@observe('state')
84+
def _state_changed(self, change):
85+
for observer in self._state_observers:
86+
observer(change["new"])
87+
88+
_state_observers = Set(allow_none=True)
89+
90+
def set_state(
91+
self,
92+
lifecycle_state: typing.Optional[types.LIFECYCLE_STATES] = None,
93+
execution_state: typing.Optional[types.EXECUTION_STATES] = None,
94+
broadcast=True
95+
):
96+
if lifecycle_state:
97+
self.lifecycle_state = lifecycle_state
98+
if execution_state:
99+
self.execution_state = execution_state
100+
101+
if broadcast:
102+
# Broadcast this state change to all listeners
103+
self.broadcast_state()
104+
105+
async def start_kernel(self, *args, **kwargs):
106+
self.set_state("starting", "starting")
107+
out = await super().start_kernel(*args, **kwargs)
108+
self.set_state("started")
109+
await self.connect()
110+
return out
111+
112+
async def shutdown_kernel(self, *args, **kwargs):
113+
self.set_state("terminating")
114+
await self.disconnect()
115+
out = await super().shutdown_kernel(*args, **kwargs)
116+
self.set_state("terminated", "dead")
117+
118+
async def restart_kernel(self, *args, **kwargs):
119+
self.set_state("restarting")
120+
return await super().restart_kernel(*args, **kwargs)
121+
122+
async def connect(self):
123+
"""Open a single client interface to the kernel.
124+
125+
Ideally this method doesn't care if the kernel
126+
is actually started. It will just try a ZMQ
127+
connection anyways and wait. This is helpful for
128+
handling 'pending' kernels, which might still
129+
be in a starting phase. We can keep a connection
130+
open regardless if the kernel is ready.
131+
"""
132+
self.set_state("connecting", "busy")
133+
# Use the new API for getting a client.
134+
self.main_client = self.client()
135+
# Track execution state by watching all messages that come through
136+
# the kernel client.
137+
self.main_client.add_listener(self.execution_state_listener)
138+
self.main_client.start_channels()
139+
await self.main_client.start_listening()
140+
# The Heartbeat channel is paused by default; unpause it here
141+
self.main_client.hb_channel.unpause()
142+
# Wait for a living heartbeat.
143+
attempt = 0
144+
while not self.main_client.hb_channel.is_alive():
145+
attempt += 1
146+
if attempt > self.time_to_connect:
147+
# Set the state to unknown.
148+
self.set_state("unknown", "unknown")
149+
raise Exception("The kernel took too long to connect to the ZMQ sockets.")
150+
# Wait a second until the next time we try again.
151+
await asyncio.sleep(1)
152+
# Send an initial kernel info request on the shell channel.
153+
self.main_client.kernel_info()
154+
self.set_state("connected")
155+
156+
async def disconnect(self):
157+
await self.main_client.stop_listening()
158+
self.main_client.stop_channels()
159+
160+
def broadcast_state(self):
161+
"""Broadcast state to all listeners"""
162+
if not self.main_client:
163+
return
164+
165+
# Emit this state to all listeners
166+
for listener in self.main_client._listeners:
167+
# Manufacture a status message
168+
session = self.main_client.session
169+
msg = session.msg("status", {"execution_state": self.execution_state})
170+
msg = session.serialize(msg)
171+
listener("iopub", msg)
172+
173+
def execution_state_listener(self, channel_name, msg):
174+
"""Set the execution state by watching messages returned by the shell channel."""
175+
# Only continue if we're on the IOPub where the status is published.
176+
if channel_name != "iopub":
177+
return
178+
session = self.main_client.session
179+
_, smsg = session.feed_identities(msg)
180+
# Unpack the message
181+
deserialized_msg = session.deserialize(smsg, content=False)
182+
if deserialized_msg["msg_type"] == "status":
183+
content = session.unpack(deserialized_msg["content"])
184+
status = content["execution_state"]
185+
if status == "starting":
186+
# Don't broadcast, since this message is already going out.
187+
self.set_state("starting", status, broadcast=False)
188+
else:
189+
parent = deserialized_msg.get("parent_header", {})
190+
msg_id = parent.get("msg_id", "")
191+
parent_channel = self.main_client.message_source_cache.get(msg_id, None)
192+
if parent_channel and parent_channel == "shell":
193+
# Don't broadcast, since this message is already going out.
194+
self.set_state("connected", status, broadcast=False)
195+

0 commit comments

Comments
 (0)