Skip to content

Commit 7a0a5b2

Browse files
committed
wire up backend to kernel execution state
1 parent 60cb682 commit 7a0a5b2

File tree

6 files changed

+51
-105
lines changed

6 files changed

+51
-105
lines changed

jupyter_rtc_core/kernels/kernel_client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from jupyter_client.asynchronous.client import AsyncKernelClient
1212
import anyio
1313
from jupyter_rtc_core.rooms.yroom import YRoom
14-
14+
from jupyter_server.utils import ensure_async
1515

1616
class DocumentAwareKernelClient(AsyncKernelClient):
1717
"""
@@ -140,7 +140,7 @@ async def _wrap_listener(listener_to_wrap, channel_name, msg):
140140
logs (instead of raises) exceptions.
141141
"""
142142
try:
143-
listener_to_wrap(channel_name, msg)
143+
await ensure_async(listener_to_wrap(channel_name, msg))
144144
except Exception as err:
145145
self.log.error(err)
146146

@@ -154,19 +154,17 @@ async def handle_outgoing_message(self, channel_name: str, msg: list[bytes]):
154154
when appropriate. Then, it routes the message
155155
to all listeners.
156156
"""
157-
158157
# Intercept messages that are IOPub focused.
159158
if channel_name == "iopub":
160159
message_returned = await self.handle_iopub_message(msg)
161160
# TODO: If the message is not returned by the iopub handler, then
162161
# return here and do not forward to listeners.
163162
if not message_returned:
164-
self.log.warn(f"If message is handled donot forward after adding output manager")
163+
self.log.warn(f"If message is handled do not forward after adding output manager")
165164
return
166165

167166
# Update the last activity.
168167
# self.last_activity = self.session.msg_time
169-
170168
await self.send_message_to_listeners(channel_name, msg)
171169

172170
async def handle_iopub_message(self, msg: list[bytes]) -> t.Optional[list[bytes]]:

jupyter_rtc_core/kernels/kernel_manager.py

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -59,62 +59,25 @@ def _validate_lifecycle_state(self, proposal: dict):
5959
raise TraitError(f"lifecycle_state must be one of {LifecycleStates}")
6060
return value
6161

62-
state = Dict()
63-
64-
@default('state')
65-
def _default_state(self):
66-
return {
67-
"execution_state": self.execution_state,
68-
"lifecycle_state": self.lifecycle_state
69-
}
70-
71-
@observe('execution_state')
72-
def _observer_execution_state(self, change):
73-
state = self.state
74-
state["execution_state"] = change['new']
75-
self.state = state
76-
77-
@observe('lifecycle_state')
78-
def _observer_lifecycle_state(self, change):
79-
state = self.state
80-
state["lifecycle_state"] = change['new']
81-
self.state = state
82-
83-
@validate('state')
84-
def _validate_state(self, change):
85-
value = change['value']
86-
if 'execution_state' not in value or 'lifecycle_state' not in value:
87-
TraitError("State needs to include execution_state and lifecycle_state")
88-
return value
89-
90-
@observe('state')
91-
def _state_changed(self, change):
92-
for observer in self._state_observers:
93-
observer(change["new"])
94-
95-
_state_observers = Set(allow_none=True)
96-
9762
def set_state(
9863
self,
9964
lifecycle_state: LifecycleStates = None,
10065
execution_state: ExecutionStates = None,
101-
broadcast=True
10266
):
10367
if lifecycle_state:
10468
self.lifecycle_state = lifecycle_state.value
10569
if execution_state:
10670
self.execution_state = execution_state.value
107-
108-
if broadcast:
109-
# Broadcast this state change to all listeners
110-
self._state_observers = None
111-
self.broadcast_state()
11271

11372
async def start_kernel(self, *args, **kwargs):
11473
self.set_state(LifecycleStates.STARTING, ExecutionStates.STARTING)
11574
out = await super().start_kernel(*args, **kwargs)
11675
self.set_state(LifecycleStates.STARTED)
117-
await self.connect()
76+
# Schedule the kernel to connect.
77+
# Do not await here, since many clients expect
78+
# the server to complete the start flow even
79+
# if the kernel is not fully connected yet.
80+
task = asyncio.create_task(self.connect())
11881
return out
11982

12083
async def shutdown_kernel(self, *args, **kwargs):
@@ -137,12 +100,13 @@ async def connect(self):
137100
be in a starting phase. We can keep a connection
138101
open regardless if the kernel is ready.
139102
"""
140-
self.set_state(LifecycleStates.CONNECTING, ExecutionStates.BUSY)
141103
# Use the new API for getting a client.
142104
self.main_client = self.client()
143105
# Track execution state by watching all messages that come through
144106
# the kernel client.
145107
self.main_client.add_listener(self.execution_state_listener)
108+
self.set_state(LifecycleStates.CONNECTING, ExecutionStates.STARTING)
109+
await self.broadcast_state()
146110
self.main_client.start_channels()
147111
await self.main_client.start_listening()
148112
# The Heartbeat channel is paused by default; unpause it here
@@ -157,33 +121,34 @@ async def connect(self):
157121
raise Exception("The kernel took too long to connect to the ZMQ sockets.")
158122
# Wait a second until the next time we try again.
159123
await asyncio.sleep(1)
160-
# Send an initial kernel info request on the shell channel.
161-
self.main_client.send_kernel_info()
162-
self.set_state(LifecycleStates.CONNECTED)
124+
# Wait for the kernel to reach an idle state.
125+
while self.execution_state != ExecutionStates.IDLE.value:
126+
self.main_client.send_kernel_info()
127+
await asyncio.sleep(1)
163128

164129
async def disconnect(self):
165130
await self.main_client.stop_listening()
166131
self.main_client.stop_channels()
167132

168-
def broadcast_state(self):
133+
async def broadcast_state(self):
169134
"""Broadcast state to all listeners"""
170135
if not self.main_client:
171136
return
172-
173-
# Emit this state to all listeners
174-
for listener in self.main_client._listeners:
175-
# Manufacture a status message
176-
session = self.main_client.session
177-
msg = session.msg("status", {"execution_state": self.execution_state})
178-
msg = session.serialize(msg)
179-
_, fed_msg_list = self.session.feed_identities(msg)
180-
listener("iopub", fed_msg_list)
137+
138+
# Manufacture an IOPub status message from the shell channel.
139+
session = self.main_client.session
140+
parent_header = session.msg_header("status")
141+
parent_msg_id = parent_header["msg_id"]
142+
self.main_client.message_source_cache[parent_msg_id] = "shell"
143+
msg = session.msg("status", content={"execution_state": self.execution_state}, parent=parent_header)
144+
smsg = session.serialize(msg)[1:]
145+
await self.main_client.handle_outgoing_message("iopub", smsg)
181146

182147
def execution_state_listener(self, channel_name: str, msg: list[bytes]):
183148
"""Set the execution state by watching messages returned by the shell channel."""
184149
# Only continue if we're on the IOPub where the status is published.
185150
if channel_name != "iopub":
186-
return
151+
return
187152

188153
session = self.main_client.session
189154
# Unpack the message
@@ -193,22 +158,18 @@ def execution_state_listener(self, channel_name: str, msg: list[bytes]):
193158
execution_state = content["execution_state"]
194159
if execution_state == "starting":
195160
# Don't broadcast, since this message is already going out.
196-
self.set_state(LifecycleStates.STARTING, ExecutionStates.STARTING, broadcast=False)
161+
self.set_state(execution_state=ExecutionStates.STARTING)
197162
else:
198163
parent = deserialized_msg.get("parent_header", {})
199164
msg_id = parent.get("msg_id", "")
200165
parent_channel = self.main_client.message_source_cache.get(msg_id, None)
201166
if parent_channel and parent_channel == "shell":
202-
# Don't broadcast, since this message is already going out.
203-
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state), broadcast=False)
204-
167+
self.set_state(LifecycleStates.CONNECTED, ExecutionStates(execution_state))
168+
205169
kernel_status = {
206170
"execution_state": self.execution_state,
207171
"lifecycle_state": self.lifecycle_state
208-
}
172+
}
209173
self.log.debug(f"Sending kernel status awareness {kernel_status}")
210174
self.main_client.send_kernel_awareness(kernel_status)
211-
self.log.debug(f"Sent kernel status awareness {kernel_status}")
212-
213-
214-
175+
self.log.debug(f"Sent kernel status awareness {kernel_status}")

jupyter_rtc_core/kernels/websocket_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async def connect(self):
2020
asyncio task happening in parallel.
2121
"""
2222
self.kernel_manager.main_client.add_listener(self.handle_outgoing_message)
23-
self.kernel_manager.broadcast_state()
23+
await self.kernel_manager.broadcast_state()
2424
self.log.info("Kernel websocket is now listening to kernel.")
2525

2626
def disconnect(self):

jupyter_rtc_core/rooms/yroom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def send_server_awareness(self, type: str, changes: tuple[dict[str, Any], Any])
384384
Arguments:
385385
type: The change type.
386386
changes: The awareness changes.
387-
"""
387+
"""
388388
if type != "update" or changes[1] != "local":
389389
return
390390

src/executionindicator.tsx

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ import {
1313
ExecutionIndicatorComponent,
1414
ExecutionIndicator as E,
1515
} from '@jupyterlab/notebook';
16-
import { Poll } from '@lumino/polling';
17-
18-
19-
const EXECUTION_STATE_KEY = "executionState";
2016

2117

2218
/**
@@ -77,22 +73,7 @@ export namespace AwarenessExecutionIndicator {
7773
if (!nb) {
7874
return;
7975
}
80-
(this as any)._currentNotebook = nb;
81-
// Artificially toggle the execution state.
82-
new Poll({
83-
auto: true,
84-
factory: () => {
85-
let fullState = nb?.model?.sharedModel.awareness.getLocalState();
86-
let state = 'idle'
87-
if (fullState && fullState[EXECUTION_STATE_KEY] === "idle") {
88-
state = 'busy'
89-
}
90-
nb?.model?.sharedModel.awareness.setLocalState({ executionState: state})
91-
return Promise.resolve()
92-
},
93-
frequency: { interval: 2000, backoff: false }
94-
});
95-
76+
(this as any)._currentNotebook = nb;
9677
(this as any)._notebookExecutionProgress.set(nb, {
9778
executionStatus: 'idle',
9879
kernelStatus: 'idle',
@@ -107,13 +88,17 @@ export namespace AwarenessExecutionIndicator {
10788

10889
const contextStatusChanged = (ctx: ISessionContext) => {
10990
if (state) {
110-
let fullState = nb?.model?.sharedModel.awareness.getLocalState();
111-
if (fullState) {
112-
let currentState = fullState[EXECUTION_STATE_KEY];
113-
state.kernelStatus = currentState;
91+
let awarenessStates = nb?.model?.sharedModel.awareness.getStates();
92+
if (awarenessStates) {
93+
for (let [_, clientState] of awarenessStates) {
94+
if ('kernel' in clientState) {
95+
state.kernelStatus = clientState['kernel']['execution_state'];
96+
this.stateChanged.emit(void 0);
97+
return;
98+
}
99+
}
114100
}
115101
}
116-
this.stateChanged.emit(void 0);
117102
};
118103

119104
nb?.model?.sharedModel.awareness.on('change', contextStatusChanged);

src/kernelstatus.tsx

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import {
1616
import { NotebookPanel } from '@jupyterlab/notebook';
1717
import { Session } from '@jupyterlab/services';
1818

19-
const EXECUTION_STATE_KEY = "executionState";
20-
2119
/**
2220
* A pure functional component for rendering kernel status.
2321
*/
@@ -133,11 +131,15 @@ export namespace AwarenessKernelStatus {
133131
let panel = (widget as NotebookPanel);
134132
const stateChanged = () => {
135133
if (this) {
136-
let fullState = panel?.model?.sharedModel.awareness.getLocalState();
137-
if (fullState) {
138-
let currentState = fullState[EXECUTION_STATE_KEY];
139-
(this as any)._kernelStatus = currentState;
140-
(this as any).stateChanged.emit(void 0);
134+
let awarenessStates = panel?.model?.sharedModel.awareness.getStates();
135+
if (awarenessStates) {
136+
for (let [_, clientState] of awarenessStates) {
137+
if ('kernel' in clientState) {
138+
(this as any)._kernelStatus = clientState['kernel']['execution_state'];
139+
(this as any).stateChanged.emit(void 0);
140+
return;
141+
}
142+
}
141143
}
142144
}
143145
};

0 commit comments

Comments
 (0)