Skip to content

Commit 57a7db7

Browse files
committed
Add async API
1 parent 650fd75 commit 57a7db7

File tree

5 files changed

+695
-1
lines changed

5 files changed

+695
-1
lines changed

jupyter_client/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
from .client import KernelClient
77
from .manager import KernelManager, run_kernel
88
from .blocking import BlockingKernelClient
9+
from .asynchronous import AsyncKernelClient
910
from .multikernelmanager import MultiKernelManager
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .client import AsyncKernelClient
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""Base classes to manage a Client's interaction with a running kernel"""
2+
3+
# Copyright (c) Jupyter Development Team.
4+
# Distributed under the terms of the Modified BSD License.
5+
6+
from __future__ import absolute_import
7+
8+
import atexit
9+
import errno
10+
from threading import Thread, Event
11+
import time
12+
import asyncio
13+
14+
import zmq
15+
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
16+
# during garbage collection of threads at exit:
17+
from zmq import ZMQError
18+
19+
from jupyter_client import protocol_version_info
20+
21+
from ..channelsabc import HBChannelABC
22+
23+
try:
24+
from queue import Queue, Empty # Py 3
25+
except ImportError:
26+
from Queue import Queue, Empty # Py 2
27+
28+
29+
class ZMQSocketChannel(object):
30+
"""A ZMQ socket in a simple async API"""
31+
session = None
32+
socket = None
33+
stream = None
34+
_exiting = False
35+
proxy_methods = []
36+
37+
def __init__(self, socket, session, loop=None):
38+
"""Create a channel.
39+
40+
Parameters
41+
----------
42+
socket : :class:`zmq.asyncio.Socket`
43+
The ZMQ socket to use.
44+
session : :class:`session.Session`
45+
The session to use.
46+
loop
47+
Unused here, for other implementations
48+
"""
49+
super(ZMQSocketChannel, self).__init__()
50+
51+
self.socket = socket
52+
self.session = session
53+
54+
async def _recv(self, **kwargs):
55+
msg = await self.socket.recv_multipart(**kwargs)
56+
ident,smsg = self.session.feed_identities(msg)
57+
return self.session.deserialize(smsg)
58+
59+
async def get_msg(self, timeout=None):
60+
""" Gets a message if there is one that is ready. """
61+
if timeout is not None:
62+
timeout *= 1000 # seconds to ms
63+
ready = await self.socket.poll(timeout)
64+
65+
if ready:
66+
return await self._recv()
67+
else:
68+
raise Empty
69+
70+
async def get_msgs(self):
71+
""" Get all messages that are currently ready. """
72+
msgs = []
73+
while True:
74+
try:
75+
msgs.append(await self.get_msg())
76+
except Empty:
77+
break
78+
return msgs
79+
80+
async def msg_ready(self):
81+
""" Is there a message that has been received? """
82+
return bool(await self.socket.poll(timeout=0))
83+
84+
def close(self):
85+
if self.socket is not None:
86+
try:
87+
self.socket.close(linger=0)
88+
except Exception:
89+
pass
90+
self.socket = None
91+
stop = close
92+
93+
def is_alive(self):
94+
return (self.socket is not None)
95+
96+
def send(self, msg):
97+
"""Pass a message to the ZMQ socket to send
98+
"""
99+
self.session.send(self.socket, msg)
100+
101+
def start(self):
102+
pass
103+
#-----------------------------------------------------------------------------
104+
# Constants and exceptions
105+
#-----------------------------------------------------------------------------
106+
107+
major_protocol_version = protocol_version_info[0]
108+
109+
class InvalidPortNumber(Exception):
110+
pass
111+
112+
class HBChannel(Thread):
113+
"""The heartbeat channel which monitors the kernel heartbeat.
114+
115+
Note that the heartbeat channel is paused by default. As long as you start
116+
this channel, the kernel manager will ensure that it is paused and un-paused
117+
as appropriate.
118+
"""
119+
context = None
120+
session = None
121+
socket = None
122+
address = None
123+
_exiting = False
124+
125+
time_to_dead = 1.
126+
poller = None
127+
_running = None
128+
_pause = None
129+
_beating = None
130+
131+
def __init__(self, context=None, session=None, address=None, loop=None):
132+
"""Create the heartbeat monitor thread.
133+
134+
Parameters
135+
----------
136+
context : :class:`zmq.Context`
137+
The ZMQ context to use.
138+
session : :class:`session.Session`
139+
The session to use.
140+
address : zmq url
141+
Standard (ip, port) tuple that the kernel is listening on.
142+
"""
143+
super(HBChannel, self).__init__()
144+
self.daemon = True
145+
146+
self.loop = loop
147+
148+
self.context = context
149+
self.session = session
150+
if isinstance(address, tuple):
151+
if address[1] == 0:
152+
message = 'The port number for a channel cannot be 0.'
153+
raise InvalidPortNumber(message)
154+
address = "tcp://%s:%i" % address
155+
self.address = address
156+
157+
# running is False until `.start()` is called
158+
self._running = False
159+
self._exit = Event()
160+
# don't start paused
161+
self._pause = False
162+
self.poller = zmq.Poller()
163+
164+
@staticmethod
165+
@atexit.register
166+
def _notice_exit():
167+
# Class definitions can be torn down during interpreter shutdown.
168+
# We only need to set _exiting flag if this hasn't happened.
169+
if HBChannel is not None:
170+
HBChannel._exiting = True
171+
172+
def _create_socket(self):
173+
if self.socket is not None:
174+
# close previous socket, before opening a new one
175+
self.poller.unregister(self.socket)
176+
self.socket.close()
177+
self.socket = self.context.socket(zmq.REQ)
178+
self.socket.linger = 1000
179+
self.socket.connect(self.address)
180+
181+
self.poller.register(self.socket, zmq.POLLIN)
182+
183+
def _poll(self, start_time):
184+
"""poll for heartbeat replies until we reach self.time_to_dead.
185+
186+
Ignores interrupts, and returns the result of poll(), which
187+
will be an empty list if no messages arrived before the timeout,
188+
or the event tuple if there is a message to receive.
189+
"""
190+
191+
until_dead = self.time_to_dead - (time.time() - start_time)
192+
# ensure poll at least once
193+
until_dead = max(until_dead, 1e-3)
194+
events = []
195+
while True:
196+
try:
197+
events = self.poller.poll(1000 * until_dead)
198+
except ZMQError as e:
199+
if e.errno == errno.EINTR:
200+
# ignore interrupts during heartbeat
201+
# this may never actually happen
202+
until_dead = self.time_to_dead - (time.time() - start_time)
203+
until_dead = max(until_dead, 1e-3)
204+
pass
205+
else:
206+
raise
207+
except Exception:
208+
if self._exiting:
209+
break
210+
else:
211+
raise
212+
else:
213+
break
214+
return events
215+
216+
def run(self):
217+
"""The thread's main activity. Call start() instead."""
218+
if self.loop is not None:
219+
asyncio.set_event_loop(self.loop)
220+
self._create_socket()
221+
self._running = True
222+
self._beating = True
223+
224+
while self._running:
225+
if self._pause:
226+
# just sleep, and skip the rest of the loop
227+
self._exit.wait(self.time_to_dead)
228+
continue
229+
230+
since_last_heartbeat = 0.0
231+
# no need to catch EFSM here, because the previous event was
232+
# either a recv or connect, which cannot be followed by EFSM
233+
self.socket.send(b'ping')
234+
request_time = time.time()
235+
ready = self._poll(request_time)
236+
if ready:
237+
self._beating = True
238+
# the poll above guarantees we have something to recv
239+
self.socket.recv()
240+
# sleep the remainder of the cycle
241+
remainder = self.time_to_dead - (time.time() - request_time)
242+
if remainder > 0:
243+
self._exit.wait(remainder)
244+
continue
245+
else:
246+
# nothing was received within the time limit, signal heart failure
247+
self._beating = False
248+
since_last_heartbeat = time.time() - request_time
249+
self.call_handlers(since_last_heartbeat)
250+
# and close/reopen the socket, because the REQ/REP cycle has been broken
251+
self._create_socket()
252+
continue
253+
254+
def pause(self):
255+
"""Pause the heartbeat."""
256+
self._pause = True
257+
258+
def unpause(self):
259+
"""Unpause the heartbeat."""
260+
self._pause = False
261+
262+
def is_beating(self):
263+
"""Is the heartbeat running and responsive (and not paused)."""
264+
if self.is_alive() and not self._pause and self._beating:
265+
return True
266+
else:
267+
return False
268+
269+
def stop(self):
270+
"""Stop the channel's event loop and join its thread."""
271+
self._running = False
272+
self._exit.set()
273+
self.join()
274+
self.close()
275+
276+
def close(self):
277+
if self.socket is not None:
278+
try:
279+
self.socket.close(linger=0)
280+
except Exception:
281+
pass
282+
self.socket = None
283+
284+
def call_handlers(self, since_last_heartbeat):
285+
"""This method is called in the ioloop thread when a message arrives.
286+
287+
Subclasses should override this method to handle incoming messages.
288+
It is important to remember that this method is called in the thread
289+
so that some logic must be done to ensure that the application level
290+
handlers are called in the application thread.
291+
"""
292+
pass
293+
294+
295+
HBChannelABC.register(HBChannel)

0 commit comments

Comments
 (0)