Skip to content

Commit 27f8a62

Browse files
authored
cleanup after PR review (#23)
* cleanup after PR review * Update jupyter_rtc_core/kernels/kernel_manager.py
1 parent 0c7b452 commit 27f8a62

File tree

4 files changed

+82
-67
lines changed

4 files changed

+82
-67
lines changed

jupyter_rtc_core/kernels/kernel_client.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from jupyter_client.asynchronous.client import AsyncKernelClient
99
import anyio
1010

11-
from jupyter_client.session import Session
1211

1312
class NextGenAsyncKernelClient(AsyncKernelClient):
1413
"""
@@ -23,7 +22,7 @@ class NextGenAsyncKernelClient(AsyncKernelClient):
2322
message_source_cache = Instance(
2423
default_value=LRUCache(maxsize=1000), klass=LRUCache
2524
)
26-
25+
2726
# A set of callables that are called when a
2827
# ZMQ message comes back from the kernel.
2928
_listeners = Set(allow_none=True)
@@ -44,15 +43,16 @@ async def _listening():
4443
# Background this task.
4544
self._listening_task = asyncio.create_task(_listening())
4645

46+
4747
async def stop_listening(self):
4848
# If the listening task isn't defined yet
4949
# do nothing.
5050
if not self._listening_task:
5151
return
5252

5353
# Attempt to cancel the task.
54-
self._listening_task.cancel()
5554
try:
55+
self._listening_task.cancel()
5656
# Await cancellation.
5757
await self._listening_task
5858
except asyncio.CancelledError:
@@ -86,13 +86,13 @@ async def recv_message(self, channel_name, msg):
8686
async with anyio.create_task_group() as tg:
8787
# Broadcast the message to all listeners.
8888
for listener in self._listeners:
89-
async def _wrap_listener(listener, channel_name, msg):
89+
async def _wrap_listener(listener_to_wrap, channel_name, msg):
9090
"""
9191
Wrap the listener to ensure its async and
9292
logs (instead of raises) exceptions.
9393
"""
9494
try:
95-
listener(channel_name, msg)
95+
listener_to_wrap(channel_name, msg)
9696
except Exception as err:
9797
self.log.error(err)
9898

@@ -109,7 +109,7 @@ def add_listener(self, callback: t.Callable[[dict], None]):
109109
self._listeners.add(callback)
110110

111111
def remove_listener(self, callback: t.Callable[[dict], None]):
112-
"""Remove a listener to teh ZMQ interface. If the listener
112+
"""Remove a listener to the ZMQ interface. If the listener
113113
is not found, this method does nothing.
114114
"""
115115
self._listeners.discard(callback)
@@ -130,7 +130,10 @@ async def _listen_for_messages(self, channel_name):
130130
except Exception as err:
131131
self.log.error(err)
132132

133-
def kernel_info(self):
133+
def send_kernel_info(self):
134+
"""Sends a kernel info message on the shell channel. Useful
135+
for determining if the kernel is busy or idle.
136+
"""
134137
msg = self.session.msg("kernel_info_request")
135138
# Send message, skipping the delimiter and signature
136139
msg = self.session.serialize(msg)[2:]

jupyter_rtc_core/kernels/kernel_manager.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
from jupyter_client.manager import AsyncKernelManager
1717

18-
from . import types
19-
from . import states
18+
# from . import types
19+
from .states import ExecutionStates, LifecycleStates
2020
from .kernel_client import AsyncKernelClient
2121

2222

@@ -30,27 +30,34 @@ class NextGenKernelManager(AsyncKernelManager):
3030

3131
client_factory: Type = Type(klass="jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient")
3232

33-
# Configurable settings in a kernel manager that I want.
34-
time_to_connect: int = Int(
33+
connection_attempts: int = Int(
3534
default_value=10,
36-
help="The timeout for connecting to a kernel."
35+
help="The number of initial heartbeat attempts once the kernel is alive. Each attempt is 1 second apart."
3736
).tag(config=True)
3837

39-
execution_state: types.EXECUTION_STATES = Unicode()
38+
execution_state: ExecutionStates = Unicode()
4039

4140
@validate("execution_state")
4241
def _validate_execution_state(self, proposal: dict):
43-
if not proposal["value"] in states.EXECUTION_STATES:
44-
raise TraitError(f"execution_state must be one of {states.EXECUTION_STATES}")
45-
return proposal["value"]
42+
value = proposal["value"]
43+
if type(value) == ExecutionStates:
44+
# Extract the enum value.
45+
value = value.value
46+
if not value in ExecutionStates:
47+
raise TraitError(f"execution_state must be one of {ExecutionStates}")
48+
return value
4649

47-
lifecycle_state: types.EXECUTION_STATES = Unicode()
50+
lifecycle_state: LifecycleStates = Unicode()
4851

4952
@validate("lifecycle_state")
5053
def _validate_lifecycle_state(self, proposal: dict):
51-
if not proposal["value"] in states.LIFECYCLE_STATES:
52-
raise TraitError(f"lifecycle_state must be one of {states.LIFECYCLE_STATES}")
53-
return proposal["value"]
54+
value = proposal["value"]
55+
if type(value) == LifecycleStates:
56+
# Extract the enum value.
57+
value = value.value
58+
if not value in LifecycleStates:
59+
raise TraitError(f"lifecycle_state must be one of {LifecycleStates}")
60+
return value
5461

5562
state = Dict()
5663

@@ -89,34 +96,34 @@ def _state_changed(self, change):
8996

9097
def set_state(
9198
self,
92-
lifecycle_state: typing.Optional[types.LIFECYCLE_STATES] = None,
93-
execution_state: typing.Optional[types.EXECUTION_STATES] = None,
99+
lifecycle_state: LifecycleStates = None,
100+
execution_state: ExecutionStates = None,
94101
broadcast=True
95102
):
96103
if lifecycle_state:
97-
self.lifecycle_state = lifecycle_state
104+
self.lifecycle_state = lifecycle_state.value
98105
if execution_state:
99-
self.execution_state = execution_state
106+
self.execution_state = execution_state.value
100107

101108
if broadcast:
102109
# Broadcast this state change to all listeners
103110
self.broadcast_state()
104111

105112
async def start_kernel(self, *args, **kwargs):
106-
self.set_state("starting", "starting")
113+
self.set_state(LifecycleStates.STARTING, ExecutionStates.STARTING)
107114
out = await super().start_kernel(*args, **kwargs)
108-
self.set_state("started")
115+
self.set_state(LifecycleStates.STARTED)
109116
await self.connect()
110117
return out
111118

112119
async def shutdown_kernel(self, *args, **kwargs):
113-
self.set_state("terminating")
120+
self.set_state(LifecycleStates.TERMINATING)
114121
await self.disconnect()
115122
out = await super().shutdown_kernel(*args, **kwargs)
116-
self.set_state("terminated", "dead")
123+
self.set_state(LifecycleStates.TERMINATED, ExecutionStates.DEAD)
117124

118125
async def restart_kernel(self, *args, **kwargs):
119-
self.set_state("restarting")
126+
self.set_state(LifecycleStates.RESTARTING)
120127
return await super().restart_kernel(*args, **kwargs)
121128

122129
async def connect(self):
@@ -129,7 +136,7 @@ async def connect(self):
129136
be in a starting phase. We can keep a connection
130137
open regardless if the kernel is ready.
131138
"""
132-
self.set_state("connecting", "busy")
139+
self.set_state(LifecycleStates.CONNECTING, ExecutionStates.BUSY)
133140
# Use the new API for getting a client.
134141
self.main_client = self.client()
135142
# Track execution state by watching all messages that come through
@@ -143,15 +150,15 @@ async def connect(self):
143150
attempt = 0
144151
while not self.main_client.hb_channel.is_alive():
145152
attempt += 1
146-
if attempt > self.time_to_connect:
153+
if attempt > self.connection_attempts:
147154
# Set the state to unknown.
148-
self.set_state("unknown", "unknown")
155+
self.set_state(LifecycleStates.UNKNOWN, ExecutionStates.UNKNOWN)
149156
raise Exception("The kernel took too long to connect to the ZMQ sockets.")
150157
# Wait a second until the next time we try again.
151158
await asyncio.sleep(1)
152159
# Send an initial kernel info request on the shell channel.
153-
self.main_client.kernel_info()
154-
self.set_state("connected")
160+
self.main_client.send_kernel_info()
161+
self.set_state(LifecycleStates.CONNECTED)
155162

156163
async def disconnect(self):
157164
await self.main_client.stop_listening()
@@ -181,15 +188,15 @@ def execution_state_listener(self, channel_name, msg):
181188
deserialized_msg = session.deserialize(smsg, content=False)
182189
if deserialized_msg["msg_type"] == "status":
183190
content = session.unpack(deserialized_msg["content"])
184-
status = content["execution_state"]
185-
if status == "starting":
191+
execution_state = content["execution_state"]
192+
if execution_state == "starting":
186193
# Don't broadcast, since this message is already going out.
187-
self.set_state("starting", status, broadcast=False)
194+
self.set_state(LifecycleStates.STARTING, execution_state, broadcast=False)
188195
else:
189196
parent = deserialized_msg.get("parent_header", {})
190197
msg_id = parent.get("msg_id", "")
191198
parent_channel = self.main_client.message_source_cache.get(msg_id, None)
192199
if parent_channel and parent_channel == "shell":
193200
# Don't broadcast, since this message is already going out.
194-
self.set_state("connected", status, broadcast=False)
201+
self.set_state(LifecycleStates.CONNECTED, execution_state, broadcast=False)
195202

jupyter_rtc_core/kernels/states.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,35 @@
1-
import typing
2-
from . import types
1+
from enum import Enum
2+
from enum import EnumMeta
33

4-
EXECUTION_STATES: typing.Tuple[types.EXECUTION_STATES] = typing.get_args(types.EXECUTION_STATES)
5-
LIFECYCLE_STATES: typing.Tuple[types.LIFECYCLE_STATES] = typing.get_args(types.LIFECYCLE_STATES)
6-
LIFECYCLE_DEAD_STATES = ["dead", "disconnected", "terminated"]
4+
class StrContainerEnumMeta(EnumMeta):
5+
def __contains__(cls, item):
6+
for name, member in cls.__members__.items():
7+
if item == name or item == member.value:
8+
return True
9+
return False
10+
class StrContainerEnum(str, Enum, metaclass=StrContainerEnumMeta):
11+
"""A Enum object that enables search for items
12+
in a normal Enum object based on key and value.
13+
"""
14+
15+
class LifecycleStates(StrContainerEnum):
16+
UNKNOWN = "unknown"
17+
STARTING = "starting"
18+
STARTED = "started"
19+
TERMINATING = "terminating"
20+
CONNECTING = "connecting"
21+
CONNECTED = "connected"
22+
RESTARTING = "restarting"
23+
RECONNECTING = "reconnecting"
24+
CULLED = "culled"
25+
DISCONNECTED = "disconnected"
26+
TERMINATED = "terminated"
27+
DEAD = "dead"
28+
29+
30+
class ExecutionStates(StrContainerEnum):
31+
BUSY = "busy"
32+
IDLE = "idle"
33+
STARTING = "starting"
34+
UNKNOWN = "unknown"
35+
DEAD = "dead"

jupyter_rtc_core/kernels/types.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)