1515
1616from jupyter_client .manager import AsyncKernelManager
1717
18- from . import types
19- from . import states
18+ # from . import types
19+ from .states import ExecutionStates , LifecycleStates
2020from .kernel_client import AsyncKernelClient
2121
2222
@@ -30,27 +30,34 @@ class NextGenKernelManager(AsyncKernelManager):
3030
3131 client_factory : Type = Type (klass = "jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient" )
3232
33- # Configurable settings in a kernel manager that I want.
34- time_to_connect : int = Int (
33+ connection_attempts : int = Int (
3534 default_value = 10 ,
36- help = "The timeout for connecting to a kernel."
35+ help = "The number of initial heartbeat attempts once the kernel is alive. Each attempt is 1 second apart ."
3736 ).tag (config = True )
3837
39- execution_state : types . EXECUTION_STATES = Unicode ()
38+ execution_state : ExecutionStates = Unicode ()
4039
4140 @validate ("execution_state" )
4241 def _validate_execution_state (self , proposal : dict ):
43- if not proposal ["value" ] in states .EXECUTION_STATES :
44- raise TraitError (f"execution_state must be one of { states .EXECUTION_STATES } " )
45- return proposal ["value" ]
42+ value = proposal ["value" ]
43+ if type (value ) == ExecutionStates :
44+ # Extract the enum value.
45+ value = value .value
46+ if not value in ExecutionStates :
47+ raise TraitError (f"execution_state must be one of { ExecutionStates } " )
48+ return value
4649
47- lifecycle_state : types . EXECUTION_STATES = Unicode ()
50+ lifecycle_state : LifecycleStates = Unicode ()
4851
4952 @validate ("lifecycle_state" )
5053 def _validate_lifecycle_state (self , proposal : dict ):
51- if not proposal ["value" ] in states .LIFECYCLE_STATES :
52- raise TraitError (f"lifecycle_state must be one of { states .LIFECYCLE_STATES } " )
53- return proposal ["value" ]
54+ value = proposal ["value" ]
55+ if type (value ) == LifecycleStates :
56+ # Extract the enum value.
57+ value = value .value
58+ if not value in LifecycleStates :
59+ raise TraitError (f"lifecycle_state must be one of { LifecycleStates } " )
60+ return value
5461
5562 state = Dict ()
5663
@@ -89,34 +96,34 @@ def _state_changed(self, change):
8996
9097 def set_state (
9198 self ,
92- lifecycle_state : typing . Optional [ types . LIFECYCLE_STATES ] = None ,
93- execution_state : typing . Optional [ types . EXECUTION_STATES ] = None ,
99+ lifecycle_state : LifecycleStates = None ,
100+ execution_state : ExecutionStates = None ,
94101 broadcast = True
95102 ):
96103 if lifecycle_state :
97- self .lifecycle_state = lifecycle_state
104+ self .lifecycle_state = lifecycle_state . value
98105 if execution_state :
99- self .execution_state = execution_state
106+ self .execution_state = execution_state . value
100107
101108 if broadcast :
102109 # Broadcast this state change to all listeners
103110 self .broadcast_state ()
104111
105112 async def start_kernel (self , * args , ** kwargs ):
106- self .set_state ("starting" , "starting" )
113+ self .set_state (LifecycleStates . STARTING , ExecutionStates . STARTING )
107114 out = await super ().start_kernel (* args , ** kwargs )
108- self .set_state ("started" )
115+ self .set_state (LifecycleStates . STARTED )
109116 await self .connect ()
110117 return out
111118
112119 async def shutdown_kernel (self , * args , ** kwargs ):
113- self .set_state ("terminating" )
120+ self .set_state (LifecycleStates . TERMINATING )
114121 await self .disconnect ()
115122 out = await super ().shutdown_kernel (* args , ** kwargs )
116- self .set_state ("terminated" , "dead" )
123+ self .set_state (LifecycleStates . TERMINATED , ExecutionStates . DEAD )
117124
118125 async def restart_kernel (self , * args , ** kwargs ):
119- self .set_state ("restarting" )
126+ self .set_state (LifecycleStates . RESTARTING )
120127 return await super ().restart_kernel (* args , ** kwargs )
121128
122129 async def connect (self ):
@@ -129,7 +136,7 @@ async def connect(self):
129136 be in a starting phase. We can keep a connection
130137 open regardless if the kernel is ready.
131138 """
132- self .set_state ("connecting" , "busy" )
139+ self .set_state (LifecycleStates . CONNECTING , ExecutionStates . BUSY )
133140 # Use the new API for getting a client.
134141 self .main_client = self .client ()
135142 # Track execution state by watching all messages that come through
@@ -143,15 +150,15 @@ async def connect(self):
143150 attempt = 0
144151 while not self .main_client .hb_channel .is_alive ():
145152 attempt += 1
146- if attempt > self .time_to_connect :
153+ if attempt > self .connection_attempts :
147154 # Set the state to unknown.
148- self .set_state ("unknown" , "unknown" )
155+ self .set_state (LifecycleStates . UNKNOWN , ExecutionStates . UNKNOWN )
149156 raise Exception ("The kernel took too long to connect to the ZMQ sockets." )
150157 # Wait a second until the next time we try again.
151158 await asyncio .sleep (1 )
152159 # Send an initial kernel info request on the shell channel.
153- self .main_client .kernel_info ()
154- self .set_state ("connected" )
160+ self .main_client .send_kernel_info ()
161+ self .set_state (LifecycleStates . CONNECTED )
155162
156163 async def disconnect (self ):
157164 await self .main_client .stop_listening ()
@@ -181,15 +188,15 @@ def execution_state_listener(self, channel_name, msg):
181188 deserialized_msg = session .deserialize (smsg , content = False )
182189 if deserialized_msg ["msg_type" ] == "status" :
183190 content = session .unpack (deserialized_msg ["content" ])
184- status = content ["execution_state" ]
185- if status == "starting" :
191+ execution_state = content ["execution_state" ]
192+ if execution_state == "starting" :
186193 # Don't broadcast, since this message is already going out.
187- self .set_state ("starting" , status , broadcast = False )
194+ self .set_state (LifecycleStates . STARTING , execution_state , broadcast = False )
188195 else :
189196 parent = deserialized_msg .get ("parent_header" , {})
190197 msg_id = parent .get ("msg_id" , "" )
191198 parent_channel = self .main_client .message_source_cache .get (msg_id , None )
192199 if parent_channel and parent_channel == "shell" :
193200 # Don't broadcast, since this message is already going out.
194- self .set_state ("connected" , status , broadcast = False )
201+ self .set_state (LifecycleStates . CONNECTED , execution_state , broadcast = False )
195202
0 commit comments