Skip to content

Commit 62300ed

Browse files
committed
Simplify abort handling and make send_stream a required parameter for processing shell messages.
1 parent 5f71eaa commit 62300ed

File tree

1 file changed

+12
-25
lines changed

1 file changed

+12
-25
lines changed

ipykernel/kernelbase.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
to_thread,
4747
)
4848
from anyio.abc import TaskStatus
49+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4950
from IPython.core.error import StdinNotImplementedError
5051
from jupyter_client.session import Session
5152
from traitlets.config.configurable import SingletonConfigurable
@@ -92,7 +93,7 @@ def _accepts_parameters(meth, param_names):
9293
class Kernel(SingletonConfigurable):
9394
"""The base kernel class."""
9495

95-
_aborted_time: float
96+
_aborted_time: float = time.monotonic()
9697

9798
# ---------------------------------------------------------------------------
9899
# Kernel interface
@@ -436,11 +437,11 @@ async def shell_main(self, subshell_id: str | None):
436437
await to_thread.run_sync(self.shell_stop.wait)
437438
tg.cancel_scope.cancel()
438439

439-
async def _execute_request_handler(self, receive_stream):
440+
async def _execute_request_handler(self, receive_stream: MemoryObjectReceiveStream):
440441
async with receive_stream:
441-
async for handler, (socket, idents, msg) in receive_stream:
442+
async for handler, (received_time, socket, idents, msg) in receive_stream:
442443
try:
443-
if self._aborting:
444+
if received_time < self._aborted_time:
444445
await self._send_abort_reply(socket, msg, idents)
445446
continue
446447
result = handler(socket, idents, msg)
@@ -450,7 +451,7 @@ async def _execute_request_handler(self, receive_stream):
450451
except BaseException as e:
451452
self.log.exception("Execute request", exc_info=e)
452453

453-
async def process_shell(self, socket, send_stream):
454+
async def process_shell(self, socket, send_stream: MemoryObjectSendStream):
454455
# socket=None is valid if kernel subshells are not supported.
455456
try:
456457
while True:
@@ -460,7 +461,9 @@ async def process_shell(self, socket, send_stream):
460461
return
461462
raise
462463

463-
async def process_shell_message(self, msg=None, socket=None, send_stream=None):
464+
async def process_shell_message(
465+
self, msg=None, socket=None, *, send_stream: MemoryObjectSendStream
466+
):
464467
# If socket is None kernel subshells are not supported so use socket=shell_socket.
465468
# If msg is set, process that message.
466469
# If msg is None, await the next message to arrive on the socket.
@@ -476,10 +479,8 @@ async def process_shell_message(self, msg=None, socket=None, send_stream=None):
476479
assert socket is None
477480
socket = self.shell_socket
478481

479-
no_msg = msg is None if self._is_test else not await socket.apoll(0).wait()
480482
msg = msg or await socket.arecv_multipart(copy=False).wait()
481483

482-
received_time = time.monotonic()
483484
copy = not isinstance(msg[0], zmq.Message)
484485
idents, msg = self.session.feed_identities(msg, copy=copy)
485486
try:
@@ -494,18 +495,6 @@ async def process_shell_message(self, msg=None, socket=None, send_stream=None):
494495

495496
msg_type = msg["header"]["msg_type"]
496497

497-
# Only abort execute requests
498-
if self._aborting and msg_type == "execute_request":
499-
if not self.stop_on_error_timeout:
500-
if no_msg:
501-
self._aborting = False
502-
elif received_time - self._aborted_time > self.stop_on_error_timeout:
503-
self._aborting = False
504-
if self._aborting:
505-
await self._send_abort_reply(socket, msg, idents)
506-
self._publish_status("idle", "shell")
507-
return
508-
509498
# Print some info about this message and leave a '--->' marker, so it's
510499
# easier to trace visually the message chain when debugging. Each
511500
# handler prints its message at the end.
@@ -529,8 +518,8 @@ async def process_shell_message(self, msg=None, socket=None, send_stream=None):
529518
except Exception:
530519
self.log.debug("Unable to signal in pre_handler_hook:", exc_info=True)
531520
try:
532-
if msg_type == "execute_request" and send_stream:
533-
await send_stream.send((handler, (socket, idents, msg)))
521+
if msg_type == "execute_request":
522+
await send_stream.send((handler, (time.monotonic(), socket, idents, msg)))
534523
else:
535524
result = handler(socket, idents, msg)
536525
if inspect.isawaitable(result):
@@ -824,9 +813,7 @@ async def execute_request(self, socket, ident, parent):
824813

825814
assert reply_msg is not None
826815
if not silent and reply_msg["content"]["status"] == "error" and stop_on_error:
827-
# while this flag is true,
828-
# execute requests will be aborted
829-
self._aborting = True
816+
# execute requests will be aborted if the received time is prior to the _aborted_time
830817
self._aborted_time = time.monotonic()
831818
self.log.info("Aborting queue")
832819

0 commit comments

Comments
 (0)