1+ import typing
2+ import asyncio
3+ from traitlets import default
4+ from traitlets import Instance
5+ from traitlets import Int
6+ from traitlets import Dict
7+ from traitlets import Type
8+ from traitlets import Unicode
9+ from traitlets import validate
10+ from traitlets import observe
11+ from traitlets import Set
12+ from traitlets import TraitError
13+ from traitlets import DottedObjectName
14+ from traitlets .utils .importstring import import_item
15+
16+ from jupyter_client .manager import AsyncKernelManager
17+
18+ from . import types
19+ from . import states
20+ from .kernel_client import AsyncKernelClient
21+
22+
23+ class NextGenKernelManager (AsyncKernelManager ):
24+
25+ main_client = Instance (AsyncKernelClient , allow_none = True )
26+
27+ client_class = DottedObjectName (
28+ "jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient"
29+ )
30+
31+ client_factory : Type = Type (klass = "jupyter_rtc_core.kernels.kernel_client.NextGenAsyncKernelClient" )
32+
33+ # Configurable settings in a kernel manager that I want.
34+ time_to_connect : int = Int (
35+ default_value = 10 ,
36+ help = "The timeout for connecting to a kernel."
37+ ).tag (config = True )
38+
39+ execution_state : types .EXECUTION_STATES = Unicode ()
40+
41+ @validate ("execution_state" )
42+ def _validate_execution_state (self , proposal : dict ):
43+ if not proposal ["value" ] in states .EXECUTION_STATES :
44+ raise TraitError (f"execution_state must be one of { states .EXECUTION_STATES } " )
45+ return proposal ["value" ]
46+
47+ lifecycle_state : types .EXECUTION_STATES = Unicode ()
48+
49+ @validate ("lifecycle_state" )
50+ def _validate_lifecycle_state (self , proposal : dict ):
51+ if not proposal ["value" ] in states .LIFECYCLE_STATES :
52+ raise TraitError (f"lifecycle_state must be one of { states .LIFECYCLE_STATES } " )
53+ return proposal ["value" ]
54+
55+ state = Dict ()
56+
57+ @default ('state' )
58+ def _default_state (self ):
59+ return {
60+ "execution_state" : self .execution_state ,
61+ "lifecycle_state" : self .lifecycle_state
62+ }
63+
64+ @observe ('execution_state' )
65+ def _observer_execution_state (self , change ):
66+ state = self .state
67+ state ["execution_state" ] = change ['new' ]
68+ self .state = state
69+
70+ @observe ('lifecycle_state' )
71+ def _observer_lifecycle_state (self , change ):
72+ state = self .state
73+ state ["lifecycle_state" ] = change ['new' ]
74+ self .state = state
75+
76+ @validate ('state' )
77+ def _validate_state (self , change ):
78+ value = change ['value' ]
79+ if 'execution_state' not in value or 'lifecycle_state' not in value :
80+ TraitError ("State needs to include execution_state and lifecycle_state" )
81+ return value
82+
83+ @observe ('state' )
84+ def _state_changed (self , change ):
85+ for observer in self ._state_observers :
86+ observer (change ["new" ])
87+
88+ _state_observers = Set (allow_none = True )
89+
90+ def set_state (
91+ self ,
92+ lifecycle_state : typing .Optional [types .LIFECYCLE_STATES ] = None ,
93+ execution_state : typing .Optional [types .EXECUTION_STATES ] = None ,
94+ broadcast = True
95+ ):
96+ if lifecycle_state :
97+ self .lifecycle_state = lifecycle_state
98+ if execution_state :
99+ self .execution_state = execution_state
100+
101+ if broadcast :
102+ # Broadcast this state change to all listeners
103+ self .broadcast_state ()
104+
105+ async def start_kernel (self , * args , ** kwargs ):
106+ self .set_state ("starting" , "starting" )
107+ out = await super ().start_kernel (* args , ** kwargs )
108+ self .set_state ("started" )
109+ await self .connect ()
110+ return out
111+
112+ async def shutdown_kernel (self , * args , ** kwargs ):
113+ self .set_state ("terminating" )
114+ await self .disconnect ()
115+ out = await super ().shutdown_kernel (* args , ** kwargs )
116+ self .set_state ("terminated" , "dead" )
117+
118+ async def restart_kernel (self , * args , ** kwargs ):
119+ self .set_state ("restarting" )
120+ return await super ().restart_kernel (* args , ** kwargs )
121+
122+ async def connect (self ):
123+ """Open a single client interface to the kernel.
124+
125+ Ideally this method doesn't care if the kernel
126+ is actually started. It will just try a ZMQ
127+ connection anyways and wait. This is helpful for
128+ handling 'pending' kernels, which might still
129+ be in a starting phase. We can keep a connection
130+ open regardless if the kernel is ready.
131+ """
132+ self .set_state ("connecting" , "busy" )
133+ # Use the new API for getting a client.
134+ self .main_client = self .client ()
135+ # Track execution state by watching all messages that come through
136+ # the kernel client.
137+ self .main_client .add_listener (self .execution_state_listener )
138+ self .main_client .start_channels ()
139+ await self .main_client .start_listening ()
140+ # The Heartbeat channel is paused by default; unpause it here
141+ self .main_client .hb_channel .unpause ()
142+ # Wait for a living heartbeat.
143+ attempt = 0
144+ while not self .main_client .hb_channel .is_alive ():
145+ attempt += 1
146+ if attempt > self .time_to_connect :
147+ # Set the state to unknown.
148+ self .set_state ("unknown" , "unknown" )
149+ raise Exception ("The kernel took too long to connect to the ZMQ sockets." )
150+ # Wait a second until the next time we try again.
151+ await asyncio .sleep (1 )
152+ # Send an initial kernel info request on the shell channel.
153+ self .main_client .kernel_info ()
154+ self .set_state ("connected" )
155+
156+ async def disconnect (self ):
157+ await self .main_client .stop_listening ()
158+ self .main_client .stop_channels ()
159+
160+ def broadcast_state (self ):
161+ """Broadcast state to all listeners"""
162+ if not self .main_client :
163+ return
164+
165+ # Emit this state to all listeners
166+ for listener in self .main_client ._listeners :
167+ # Manufacture a status message
168+ session = self .main_client .session
169+ msg = session .msg ("status" , {"execution_state" : self .execution_state })
170+ msg = session .serialize (msg )
171+ listener ("iopub" , msg )
172+
173+ def execution_state_listener (self , channel_name , msg ):
174+ """Set the execution state by watching messages returned by the shell channel."""
175+ # Only continue if we're on the IOPub where the status is published.
176+ if channel_name != "iopub" :
177+ return
178+ session = self .main_client .session
179+ _ , smsg = session .feed_identities (msg )
180+ # Unpack the message
181+ deserialized_msg = session .deserialize (smsg , content = False )
182+ if deserialized_msg ["msg_type" ] == "status" :
183+ content = session .unpack (deserialized_msg ["content" ])
184+ status = content ["execution_state" ]
185+ if status == "starting" :
186+ # Don't broadcast, since this message is already going out.
187+ self .set_state ("starting" , status , broadcast = False )
188+ else :
189+ parent = deserialized_msg .get ("parent_header" , {})
190+ msg_id = parent .get ("msg_id" , "" )
191+ parent_channel = self .main_client .message_source_cache .get (msg_id , None )
192+ if parent_channel and parent_channel == "shell" :
193+ # Don't broadcast, since this message is already going out.
194+ self .set_state ("connected" , status , broadcast = False )
195+
0 commit comments