14
14
from .channelsabc import HBChannelABC
15
15
from .session import Session
16
16
from jupyter_client import protocol_version_info
17
+ from jupyter_client .utils import ensure_async
17
18
18
19
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
19
20
# during garbage collection of threads at exit
@@ -49,15 +50,15 @@ class HBChannel(Thread):
49
50
50
51
def __init__ (
51
52
self ,
52
- context : t .Optional [zmq .asyncio . Context ] = None ,
53
+ context : t .Optional [zmq .Context ] = None ,
53
54
session : t .Optional [Session ] = None ,
54
55
address : t .Union [t .Tuple [str , int ], str ] = "" ,
55
56
):
56
57
"""Create the heartbeat monitor thread.
57
58
58
59
Parameters
59
60
----------
60
- context : :class:`zmq.asyncio. Context`
61
+ context : :class:`zmq.Context`
61
62
The ZMQ context to use.
62
63
session : :class:`session.Session`
63
64
The session to use.
@@ -106,12 +107,6 @@ def _create_socket(self) -> None:
106
107
107
108
self .poller .register (self .socket , zmq .POLLIN )
108
109
109
- def run (self ) -> None :
110
- loop = asyncio .new_event_loop ()
111
- asyncio .set_event_loop (loop )
112
- loop .run_until_complete (self ._async_run ())
113
- loop .close ()
114
-
115
110
async def _async_run (self ) -> None :
116
111
"""The thread's main activity. Call start() instead."""
117
112
self ._create_socket ()
@@ -127,16 +122,16 @@ async def _async_run(self) -> None:
127
122
128
123
since_last_heartbeat = 0.0
129
124
# no need to catch EFSM here, because the previous event was
130
- # either a recv or connect, which cannot be followed by EFSM
131
- await self .socket .send (b"ping" )
125
+ # either a recv or connect, which cannot be followed by EFSM)
126
+ await ensure_async ( self .socket .send (b"ping" ) )
132
127
request_time = time .time ()
133
128
# Wait until timeout
134
129
self ._exit .wait (self .time_to_dead )
135
130
# poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
136
131
self ._beating = bool (self .poller .poll (0 ))
137
132
if self ._beating :
138
133
# the poll above guarantees we have something to recv
139
- await self .socket .recv ()
134
+ await ensure_async ( self .socket .recv () )
140
135
continue
141
136
elif self ._running :
142
137
# nothing was received within the time limit, signal heart failure
@@ -146,6 +141,12 @@ async def _async_run(self) -> None:
146
141
self ._create_socket ()
147
142
continue
148
143
144
+ def run (self ) -> None :
145
+ loop = asyncio .new_event_loop ()
146
+ asyncio .set_event_loop (loop )
147
+ loop .run_until_complete (self ._async_run ())
148
+ loop .close ()
149
+
149
150
def pause (self ) -> None :
150
151
"""Pause the heartbeat."""
151
152
self ._pause = True
@@ -191,14 +192,14 @@ def call_handlers(self, since_last_heartbeat: float) -> None:
191
192
192
193
193
194
class ZMQSocketChannel (object ):
194
- """A ZMQ socket in an async API """
195
+ """A ZMQ socket wrapper """
195
196
196
- def __init__ (self , socket : zmq .asyncio . Socket , session : Session , loop : t .Any = None ) -> None :
197
+ def __init__ (self , socket : zmq .Socket , session : Session , loop : t .Any = None ) -> None :
197
198
"""Create a channel.
198
199
199
200
Parameters
200
201
----------
201
- socket : :class:`zmq.asyncio. Socket`
202
+ socket : :class:`zmq.Socket`
202
203
The ZMQ socket to use.
203
204
session : :class:`session.Session`
204
205
The session to use.
@@ -207,42 +208,41 @@ def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = N
207
208
"""
208
209
super ().__init__ ()
209
210
210
- self .socket : t .Optional [zmq .asyncio . Socket ] = socket
211
+ self .socket : t .Optional [zmq .Socket ] = socket
211
212
self .session = session
212
213
213
- async def _recv (self , ** kwargs : t .Any ) -> t .Dict [str , t .Any ]:
214
+ def _recv (self , ** kwargs : t .Any ) -> t .Dict [str , t .Any ]:
214
215
assert self .socket is not None
215
- msg = await self .socket .recv_multipart (** kwargs )
216
+ msg = self .socket .recv_multipart (** kwargs )
216
217
ident , smsg = self .session .feed_identities (msg )
217
218
return self .session .deserialize (smsg )
218
219
219
- async def get_msg (self , timeout : t .Optional [float ] = None ) -> t .Dict [str , t .Any ]:
220
+ def get_msg (self , timeout : t .Optional [float ] = None ) -> t .Dict [str , t .Any ]:
220
221
"""Gets a message if there is one that is ready."""
221
222
assert self .socket is not None
222
223
if timeout is not None :
223
224
timeout *= 1000 # seconds to ms
224
- ready = await self .socket .poll (timeout )
225
-
225
+ ready = self .socket .poll (timeout )
226
226
if ready :
227
- res = await self ._recv ()
227
+ res = self ._recv ()
228
228
return res
229
229
else :
230
230
raise Empty
231
231
232
- async def get_msgs (self ) -> t .List [t .Dict [str , t .Any ]]:
232
+ def get_msgs (self ) -> t .List [t .Dict [str , t .Any ]]:
233
233
"""Get all messages that are currently ready."""
234
234
msgs = []
235
235
while True :
236
236
try :
237
- msgs .append (await self .get_msg ())
237
+ msgs .append (self .get_msg ())
238
238
except Empty :
239
239
break
240
240
return msgs
241
241
242
- async def msg_ready (self ) -> bool :
242
+ def msg_ready (self ) -> bool :
243
243
"""Is there a message that has been received?"""
244
244
assert self .socket is not None
245
- return bool (await self .socket .poll (timeout = 0 ))
245
+ return bool (self .socket .poll (timeout = 0 ))
246
246
247
247
def close (self ) -> None :
248
248
if self .socket is not None :
@@ -264,3 +264,60 @@ def send(self, msg: t.Dict[str, t.Any]) -> None:
264
264
265
265
def start (self ) -> None :
266
266
pass
267
+
268
+
269
+ class AsyncZMQSocketChannel (ZMQSocketChannel ):
270
+ """A ZMQ socket in an async API"""
271
+
272
+ socket : zmq .asyncio .Socket
273
+
274
+ def __init__ (self , socket : zmq .asyncio .Socket , session : Session , loop : t .Any = None ) -> None :
275
+ """Create a channel.
276
+
277
+ Parameters
278
+ ----------
279
+ socket : :class:`zmq.asyncio.Socket`
280
+ The ZMQ socket to use.
281
+ session : :class:`session.Session`
282
+ The session to use.
283
+ loop
284
+ Unused here, for other implementations
285
+ """
286
+ if not isinstance (socket , zmq .asyncio .Socket ):
287
+ raise ValueError ('Socket must be asyncio' )
288
+ super ().__init__ (socket , session )
289
+
290
+ async def _recv (self , ** kwargs : t .Any ) -> t .Dict [str , t .Any ]: # type:ignore[override]
291
+ assert self .socket is not None
292
+ msg = await self .socket .recv_multipart (** kwargs )
293
+ _ , smsg = self .session .feed_identities (msg )
294
+ return self .session .deserialize (smsg )
295
+
296
+ async def get_msg ( # type:ignore[override]
297
+ self , timeout : t .Optional [float ] = None
298
+ ) -> t .Dict [str , t .Any ]:
299
+ """Gets a message if there is one that is ready."""
300
+ assert self .socket is not None
301
+ if timeout is not None :
302
+ timeout *= 1000 # seconds to ms
303
+ ready = await self .socket .poll (timeout )
304
+ if ready :
305
+ res = await self ._recv ()
306
+ return res
307
+ else :
308
+ raise Empty
309
+
310
+ async def get_msgs (self ) -> t .List [t .Dict [str , t .Any ]]: # type:ignore[override]
311
+ """Get all messages that are currently ready."""
312
+ msgs = []
313
+ while True :
314
+ try :
315
+ msgs .append (await self .get_msg ())
316
+ except Empty :
317
+ break
318
+ return msgs
319
+
320
+ async def msg_ready (self ) -> bool : # type:ignore[override]
321
+ """Is there a message that has been received?"""
322
+ assert self .socket is not None
323
+ return bool (await self .socket .poll (timeout = 0 ))
0 commit comments