|
4 | 4 | # Distributed under the terms of the Modified BSD License. |
5 | 5 |
|
6 | 6 | import asyncio |
| 7 | +import concurrent.futures |
7 | 8 | from datetime import datetime |
8 | 9 | from functools import partial |
9 | 10 | import itertools |
@@ -213,8 +214,34 @@ def dispatch_control(self, msg): |
213 | 214 | async def poll_control_queue(self): |
214 | 215 | while True: |
215 | 216 | msg = await self.control_queue.get() |
| 217 | + # handle tracers from _flush_control_queue |
| 218 | + if isinstance(msg, (concurrent.futures.Future, asyncio.Future)): |
| 219 | + msg.set_result(None) |
| 220 | + continue |
216 | 221 | await self.process_control(msg) |
217 | 222 |
|
| 223 | + async def _flush_control_queue(self): |
| 224 | + """Flush the control queue, wait for processing of any pending messages""" |
| 225 | + if self.control_thread: |
| 226 | + control_loop = self.control_thread.io_loop |
| 227 | + # concurrent.futures.Futures are threadsafe |
| 228 | + # and can be used to await across threads |
| 229 | + tracer_future = concurrent.futures.Future() |
| 230 | + awaitable_future = asyncio.wrap_future(tracer_future) |
| 231 | + else: |
| 232 | + control_loop = self.io_loop |
| 233 | + tracer_future = awaitable_future = asyncio.Future() |
| 234 | + |
| 235 | + def _flush(): |
| 236 | + # control_stream.flush puts messages on the queue |
| 237 | + self.control_stream.flush() |
| 238 | + # put Future on the queue after all of those, |
| 239 | + # so we can wait for all queued messages to be processed |
| 240 | + self.control_queue.put(tracer_future) |
| 241 | + |
| 242 | + control_loop.add_callback(_flush) |
| 243 | + return awaitable_future |
| 244 | + |
218 | 245 | async def process_control(self, msg): |
219 | 246 | """dispatch control requests""" |
220 | 247 | idents, msg = self.session.feed_identities(msg, copy=False) |
@@ -265,6 +292,10 @@ def should_handle(self, stream, msg, idents): |
265 | 292 |
|
266 | 293 | async def dispatch_shell(self, msg): |
267 | 294 | """dispatch shell requests""" |
| 295 | + |
| 296 | + # flush control queue before handling shell requests |
| 297 | + await self._flush_control_queue() |
| 298 | + |
268 | 299 | idents, msg = self.session.feed_identities(msg, copy=False) |
269 | 300 | try: |
270 | 301 | msg = self.session.deserialize(msg, content=True, copy=False) |
@@ -630,7 +661,7 @@ async def inspect_request(self, stream, ident, parent): |
630 | 661 | content.get('detail_level', 0), |
631 | 662 | ) |
632 | 663 | if inspect.isawaitable(reply_content): |
633 | | - reply_content = await reply_content |
| 664 | + reply_content = await reply_content |
634 | 665 |
|
635 | 666 | # Before we send this object over, we scrub it for JSON usage |
636 | 667 | reply_content = json_clean(reply_content) |
@@ -944,7 +975,7 @@ def _input_request(self, prompt, ident, parent, password=False): |
944 | 975 | raise KeyboardInterrupt("Interrupted by user") from None |
945 | 976 | except Exception as e: |
946 | 977 | self.log.warning("Invalid Message:", exc_info=True) |
947 | | - |
| 978 | + |
948 | 979 | try: |
949 | 980 | value = reply["content"]["value"] |
950 | 981 | except Exception: |
|
0 commit comments