Skip to content

Commit f35aeec

Browse files
knareshZsailerNaresh Kumar Kolloju
authored
Fixes to make local Kernel work on Document aware session (#28)
* Provide a YDoc-aware session manager and kernel client. * Make kernels work with YDoc aware session managers --------- Co-authored-by: Zach Sailer <[email protected]> Co-authored-by: Naresh Kumar Kolloju <[email protected]>
1 parent 98482ff commit f35aeec

File tree

7 files changed

+298
-84
lines changed

7 files changed

+298
-84
lines changed

jupyter_rtc_core/app.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from traitlets.config import Config
33
import asyncio
44

5+
from traitlets import Instance
6+
from traitlets import Type
57
from .handlers import RouteHandler
68
from .websockets import GlobalAwarenessWebsocket, YRoomWebsocket
7-
from .rooms import YRoomManager
9+
from .rooms.yroom_manager import YRoomManager
810

911
class RtcExtensionApp(ExtensionApp):
1012
name = "jupyter_rtc_core"
@@ -21,10 +23,22 @@ class RtcExtensionApp(ExtensionApp):
2123
# (r"api/collaboration/room/(.*)", YRoomWebsocket)
2224
]
2325

26+
yroom_manager_class = Type(
27+
klass=YRoomManager,
28+
help="""YRoom Manager Class.""",
29+
default_value=YRoomManager,
30+
).tag(config=True)
31+
32+
yroom_manager = Instance(
33+
klass=YRoomManager,
34+
help="An instance of the YRoom Manager.",
35+
allow_none=True
36+
).tag(config=True)
37+
38+
2439
def initialize(self):
2540
super().initialize()
2641

27-
2842
def initialize_settings(self):
2943
# Get YRoomManager arguments from server extension context.
3044
# We cannot access the 'file_id_manager' key immediately because server
@@ -50,5 +64,6 @@ def _link_jupyter_server_extension(self, server_app):
5064
c.ServerApp.kernel_websocket_connection_class = "jupyter_rtc_core.kernels.websocket_connection.NextGenKernelWebsocketConnection"
5165
c.ServerApp.kernel_manager_class = "jupyter_rtc_core.kernels.multi_kernel_manager.NextGenMappingKernelManager"
5266
c.MultiKernelManager.kernel_manager_class = "jupyter_rtc_core.kernels.kernel_manager.NextGenKernelManager"
67+
c.ServerApp.session_manager_class = "jupyter_rtc_core.session_manager.YDocSessionManager"
5368
server_app.update_config(c)
54-
super()._link_jupyter_server_extension(server_app)
69+
super()._link_jupyter_server_extension(server_app)
Lines changed: 144 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
A new Kernel client that is aware of ydocuments.
3+
"""
14
import asyncio
25
import json
36
import typing as t
@@ -7,26 +10,32 @@
710
from .utils import LRUCache
811
from jupyter_client.asynchronous.client import AsyncKernelClient
912
import anyio
13+
from jupyter_rtc_core.rooms.yroom import YRoom
1014

1115

12-
class NextGenAsyncKernelClient(AsyncKernelClient):
16+
class DocumentAwareKernelClient(AsyncKernelClient):
1317
"""
14-
A ZMQ-based kernel client class that managers all listeners to a kernel.
18+
A kernel client
1519
"""
16-
# Having this message cache is not ideal.
20+
# Having this message cache is not ideal.
1721
# Unfortunately, we don't include the parent channel
1822
# in the messages that generate IOPub status messages, thus,
1923
# we can't differential between the control channel vs.
20-
# shell channel status. This message cache gives us
24+
# shell channel status. This message cache gives us
2125
# the ability to map status message back to their source.
2226
message_source_cache = Instance(
2327
default_value=LRUCache(maxsize=1000), klass=LRUCache
2428
)
25-
26-
# A set of callables that are called when a
27-
# ZMQ message comes back from the kernel.
29+
30+
# A set of callables that are called when a kernel
31+
# message is received.
2832
_listeners = Set(allow_none=True)
2933

34+
# A set of YRooms that will intercept output and kernel
35+
# status messages.
36+
_yrooms: t.Set[YRoom] = Set(trait=Instance(YRoom), default_value=set())
37+
38+
3039
async def start_listening(self):
3140
"""Start listening to messages coming from the kernel.
3241
@@ -39,17 +48,18 @@ async def _listening():
3948
tg.start_soon(
4049
self._listen_for_messages, channel_name
4150
)
42-
51+
4352
# Background this task.
4453
self._listening_task = asyncio.create_task(_listening())
4554

46-
4755
async def stop_listening(self):
56+
"""Stop listening to the kernel.
57+
"""
4858
# If the listening task isn't defined yet
4959
# do nothing.
5060
if not self._listening_task:
5161
return
52-
62+
5363
# Attempt to cancel the task.
5464
try:
5565
self._listening_task.cancel()
@@ -60,10 +70,10 @@ async def stop_listening(self):
6070
# Log any exceptions that were raised.
6171
except Exception as err:
6272
self.log.error(err)
63-
73+
6474
_listening_task: t.Optional[t.Awaitable] = Any(allow_none=True)
6575

66-
def send_message(self, channel_name, msg):
76+
def handle_incoming_message(self, channel_name: str, msg: list[bytes]):
6777
"""Use the given session to send the message."""
6878
# Cache the message ID and its socket name so that
6979
# any response message can be mapped back to the
@@ -73,32 +83,17 @@ def send_message(self, channel_name, msg):
7383
self.message_source_cache[msg_id] = channel_name
7484
channel = getattr(self, f"{channel_name}_channel")
7585
channel.session.send_raw(channel.socket, msg)
76-
77-
async def recv_message(self, channel_name, msg):
78-
"""This is the main method that consumes every
79-
message coming back from the kernel. It parses the header
80-
(not the content, which might be large) and updates
81-
the last_activity, execution_state, and lifecycle_state
82-
when appropriate. Then, it routes the message
83-
to all listeners.
86+
87+
def send_kernel_info(self):
88+
"""Sends a kernel info message on the shell channel. Useful
89+
for determining if the kernel is busy or idle.
8490
"""
85-
# Broadcast messages
86-
async with anyio.create_task_group() as tg:
87-
# Broadcast the message to all listeners.
88-
for listener in self._listeners:
89-
async def _wrap_listener(listener_to_wrap, channel_name, msg):
90-
"""
91-
Wrap the listener to ensure its async and
92-
logs (instead of raises) exceptions.
93-
"""
94-
try:
95-
listener_to_wrap(channel_name, msg)
96-
except Exception as err:
97-
self.log.error(err)
98-
99-
tg.start_soon(_wrap_listener, listener, channel_name, msg)
91+
msg = self.session.msg("kernel_info_request")
92+
# Send message, skipping the delimiter and signature
93+
msg = self.session.serialize(msg)[2:]
94+
self.handle_incoming_message("shell", msg)
10095

101-
def add_listener(self, callback: t.Callable[[dict], None]):
96+
def add_listener(self, callback: t.Callable[[str, list[bytes]], None]):
10297
"""Add a listener to the ZMQ Interface.
10398
10499
A listener is a callable function/method that takes
@@ -108,13 +103,13 @@ def add_listener(self, callback: t.Callable[[dict], None]):
108103
"""
109104
self._listeners.add(callback)
110105

111-
def remove_listener(self, callback: t.Callable[[dict], None]):
112-
"""Remove a listener to the ZMQ interface. If the listener
106+
def remove_listener(self, callback: t.Callable[[str, list[bytes]], None]):
107+
"""Remove a listener. If the listener
113108
is not found, this method does nothing.
114109
"""
115110
self._listeners.discard(callback)
116111

117-
async def _listen_for_messages(self, channel_name):
112+
async def _listen_for_messages(self, channel_name: str):
118113
"""The basic polling loop for listened to kernel messages
119114
on a ZMQ socket.
120115
"""
@@ -125,16 +120,118 @@ async def _listen_for_messages(self, channel_name):
125120
# Wait for a message
126121
await channel.socket.poll(timeout=float("inf"))
127122
raw_msg = await channel.socket.recv_multipart()
123+
# Drop identities and delimit from the message parts.
124+
_, fed_msg_list = self.session.feed_identities(raw_msg)
125+
msg = fed_msg_list
128126
try:
129-
await self.recv_message(channel_name, raw_msg)
127+
await self.handle_outgoing_message(channel_name, msg)
130128
except Exception as err:
131129
self.log.error(err)
130+
131+
async def send_message_to_listeners(self, channel_name: str, msg: list[bytes]):
132+
"""
133+
Sends message to all registered listeners.
134+
"""
135+
async with anyio.create_task_group() as tg:
136+
# Broadcast the message to all listeners.
137+
for listener in self._listeners:
138+
async def _wrap_listener(listener_to_wrap, channel_name, msg):
139+
"""
140+
Wrap the listener to ensure its async and
141+
logs (instead of raises) exceptions.
142+
"""
143+
try:
144+
listener_to_wrap(channel_name, msg)
145+
except Exception as err:
146+
self.log.error(err)
147+
148+
tg.start_soon(_wrap_listener, listener, channel_name, msg)
149+
150+
async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
151+
"""This is the main method that consumes every
152+
message coming back from the kernel. It parses the header
153+
(not the content, which might be large) and updates
154+
the last_activity, execution_state, and lifecycsle_state
155+
when appropriate. Then, it routes the message
156+
to all listeners.
157+
"""
158+
159+
# Intercept messages that are IOPub focused.
160+
if channel_name == "iopub":
161+
message_returned = await self.handle_iopub_message(msg)
162+
# TODO: If the message is not returned by the iopub handler, then
163+
# return here and do not forward to listeners.
164+
if not message_returned:
165+
self.log.warn(f"If message is handled donot forward after adding output manager")
166+
return
167+
168+
# Update the last activity.
169+
#self.last_activity = self.session.msg_time
170+
171+
await self.send_message_to_listeners(channel_name, msg)
172+
173+
async def handle_iopub_message(self, msg: list[bytes]) -> t.Optional[list[bytes]]:
174+
"""
175+
Handle messages
176+
177+
Parameters
178+
----------
179+
dmsg: dict
180+
Deserialized message (except concept)
181+
182+
Returns
183+
-------
184+
Returns the message if it should be forwarded to listeners. Otherwise,
185+
returns `None` and keeps (i.e. intercepts) the message from going
186+
to listenres.
187+
"""
188+
# NOTE: Here's where we will inject the kernel state
189+
# into the awareness of a document.
190+
191+
try:
192+
dmsg = self.session.deserialize(msg, content=False)
193+
except Exception as e:
194+
self.log.error(f"Error deserializing message: {e}")
195+
raise ValueError
132196

133-
def send_kernel_info(self):
134-
"""Sends a kernel info message on the shell channel. Useful
135-
for determining if the kernel is busy or idle.
197+
if dmsg["msg_type"] == "status":
198+
# Forward to all yrooms.
199+
for yroom in self._yrooms:
200+
# NOTE: We need to create a real message here.
201+
awareness_update_message = b""
202+
self.log.debug(f"Update Awareness here: {dmsg}. YRoom: {yroom}")
203+
#self.log.debug(f"Getting YDoc: {await yroom.get_ydoc()}")
204+
#yroom.add_message(awareness_update_message)
205+
206+
# TODO: returning message temporarily to not break UI
207+
return msg
208+
209+
210+
# NOTE: Inject display data into ydoc.
211+
if dmsg["msg_type"] == "display_data":
212+
# Forward to all yrooms.
213+
for yroom in self._yrooms:
214+
update_document_message = b""
215+
self.log.debug(f"Update Document here: {dmsg}. Yroom: {yroom}")
216+
#self.log.debug(f"Getting YDoc: {await yroom.get_ydoc()}")
217+
#yroom.add_message(update_document_message)
218+
219+
# TODO: returning message temporarily to not break UI
220+
return msg
221+
222+
# If the message isn't handled above, return it and it will
223+
# be forwarded to all listeners
224+
return msg
225+
226+
async def add_yroom(self, yroom: YRoom):
136227
"""
137-
msg = self.session.msg("kernel_info_request")
138-
# Send message, skipping the delimiter and signature
139-
msg = self.session.serialize(msg)[2:]
140-
self.send_message("shell", msg)
228+
Register a YRoom with this kernel client. YRooms will
229+
intercept display and kernel status messages.
230+
"""
231+
self._yrooms.add(yroom)
232+
233+
async def remove_yroom(self, yroom: YRoom):
234+
"""
235+
De-register a YRoom from handling kernel client messages.
236+
"""
237+
self._yrooms.discard(yroom)

jupyter_rtc_core/kernels/kernel_manager.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121

2222

2323
class NextGenKernelManager(AsyncKernelManager):
24-
24+
2525
main_client = Instance(AsyncKernelClient, allow_none=True)
2626

2727
client_class = DottedObjectName(
28-
"jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient"
28+
"jupyter_rtc_core.kernels.kernel_client.DocumentAwareKernelClient"
2929
)
3030

31-
client_factory: Type = Type(klass="jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient")
31+
client_factory: Type = Type(klass="jupyter_rtc_core.kernels.kernel_client.DocumentAwareKernelClient")
3232

3333
connection_attempts: int = Int(
3434
default_value=10,
@@ -107,6 +107,8 @@ def set_state(
107107

108108
if broadcast:
109109
# Broadcast this state change to all listeners
110+
# Turn off state broadcasting temporarily to avoid
111+
self._state_observers = None
110112
self.broadcast_state()
111113

112114
async def start_kernel(self, *args, **kwargs):
@@ -175,28 +177,29 @@ def broadcast_state(self):
175177
session = self.main_client.session
176178
msg = session.msg("status", {"execution_state": self.execution_state})
177179
msg = session.serialize(msg)
178-
listener("iopub", msg)
180+
_, fed_msg_list = self.session.feed_identities(msg)
181+
listener("iopub", fed_msg_list)
179182

180-
def execution_state_listener(self, channel_name, msg):
183+
def execution_state_listener(self, channel_name: str, msg: list[bytes]):
181184
"""Set the execution state by watching messages returned by the shell channel."""
182185
# Only continue if we're on the IOPub where the status is published.
183186
if channel_name != "iopub":
184187
return
188+
185189
session = self.main_client.session
186-
_, smsg = session.feed_identities(msg)
187190
# Unpack the message
188-
deserialized_msg = session.deserialize(smsg, content=False)
191+
deserialized_msg = session.deserialize(msg, content=False)
189192
if deserialized_msg["msg_type"] == "status":
190193
content = session.unpack(deserialized_msg["content"])
191194
execution_state = content["execution_state"]
192195
if execution_state == "starting":
193196
# Don't broadcast, since this message is already going out.
194-
self.set_state(LifecycleStates.STARTING, execution_state, broadcast=False)
197+
self.set_state(LifecycleStates.STARTING, ExecutionStates.STARTING, broadcast=False)
195198
else:
196199
parent = deserialized_msg.get("parent_header", {})
197200
msg_id = parent.get("msg_id", "")
198201
parent_channel = self.main_client.message_source_cache.get(msg_id, None)
199202
if parent_channel and parent_channel == "shell":
200203
# Don't broadcast, since this message is already going out.
201-
self.set_state(LifecycleStates.CONNECTED, execution_state, broadcast=False)
204+
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state), broadcast=False)
202205

0 commit comments

Comments
 (0)