Skip to content

Commit a6f8e05

Browse files
committed
forward-port abort_queues race condition fix
abort_queues didn't work when do_execute is actually async Fix submitted upstream to ipykernel
1 parent 52267df commit a6f8e05

File tree

2 files changed

+64
-16
lines changed

2 files changed

+64
-16
lines changed

ipyparallel/engine/kernel.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import asyncio
33
import inspect
44
import sys
5+
from functools import partial
56

7+
import ipykernel
68
from ipykernel.ipkernel import IPythonKernel
79
from traitlets import Integer
810
from traitlets import Type
@@ -47,6 +49,47 @@ def __init__(self, **kwargs):
4749
data_pub.pub_socket = self.iopub_socket
4850
self.aborted = set()
4951

52+
def _abort_queues(self):
53+
# forward-port ipython/ipykernel#853
54+
# may remove after requiring ipykernel 6.9
55+
56+
# while this flag is true,
57+
# execute requests will be aborted
58+
self._aborting = True
59+
self.log.info("Aborting queue")
60+
61+
# Callback to signal that we are done aborting
62+
def stop_aborting():
63+
self.log.info("Finishing abort")
64+
self._aborting = False
65+
66+
# put stop_aborting on the message queue
67+
# so that it's handled after processing of already-pending messages
68+
if ipykernel.version_info < (6,):
69+
# 10 is SHELL priority in ipykernel 5.x
70+
streams = self.shell_streams
71+
schedule_stop_aborting = partial(self.schedule_dispatch, 10, stop_aborting)
72+
else:
73+
streams = [self.shell_stream]
74+
schedule_stop_aborting = partial(self.schedule_dispatch, stop_aborting)
75+
76+
# flush streams, so all currently waiting messages
77+
# are added to the queue
78+
for stream in streams:
79+
stream.flush()
80+
81+
# if we have a delay, give messages this long to arrive on the queue
82+
# before we start accepting requests
83+
asyncio.get_event_loop().call_later(
84+
self.stop_on_error_timeout, schedule_stop_aborting
85+
)
86+
87+
# for compatibility, return a completed Future
88+
# so this is still awaitable
89+
f = asyncio.Future()
90+
f.set_result(None)
91+
return f
92+
5093
def should_handle(self, stream, msg, idents):
5194
"""Check whether a shell-channel message should be handled
5295
@@ -194,22 +237,25 @@ def do_apply(self, content, bufs, msg_id, reply_metadata):
194237

195238
return reply_content, result_buf
196239

197-
def do_execute(self, *args, **kwargs):
240+
async def _do_execute_async(self, *args, **kwargs):
198241
super_execute = super().do_execute(*args, **kwargs)
242+
if inspect.isawaitable(super_execute):
243+
reply_content = await super_execute
244+
else:
245+
reply_content = super_execute
246+
# add engine info
247+
if reply_content['status'] == 'error':
248+
reply_content["engine_info"] = self.get_engine_info(method="execute")
249+
return reply_content
199250

200-
async def _do_execute():
201-
if inspect.isawaitable(super_execute):
202-
reply_content = await super_execute
203-
else:
204-
reply_content = super_execute
205-
# add engine info
206-
if reply_content['status'] == 'error':
207-
reply_content["engine_info"] = self.get_engine_info(method="execute")
208-
return reply_content
209-
210-
# ipykernel 5 uses gen.maybe_future which doesn't accept async def coroutines,
211-
# but it does accept asyncio.Futures
212-
return asyncio.ensure_future(_do_execute())
251+
def do_execute(self, *args, **kwargs):
252+
coro = self._do_execute_async(*args, **kwargs)
253+
if ipykernel.version_info < (6,):
254+
# ipykernel 5 uses gen.maybe_future which doesn't accept async def coroutines,
255+
# but it does accept asyncio.Futures
256+
return asyncio.ensure_future(coro)
257+
else:
258+
return coro
213259

214260
# Control messages for msgspec extensions:
215261

@@ -219,7 +265,9 @@ def abort_request(self, stream, ident, parent):
219265
if isinstance(msg_ids, str):
220266
msg_ids = [msg_ids]
221267
if not msg_ids:
222-
self._abort_queues()
268+
f = self._abort_queues()
269+
if inspect.isawaitable(f):
270+
asyncio.ensure_future(f)
223271
for mid in msg_ids:
224272
self.aborted.add(str(mid))
225273

ipyparallel/tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def teardown():
124124
try:
125125
f = p.stop()
126126
if f:
127-
asyncio.run(f)
127+
asyncio.get_event_loop().run_until_complete(f)
128128
except Exception as e:
129129
print(e)
130130
pass

0 commit comments

Comments
 (0)