Skip to content

Commit 7b41bc6

Browse files
authored
Use asyncio.Lock around subshell message handling (#1430)
1 parent 5b8ac29 commit 7b41bc6

File tree

6 files changed

+54
-58
lines changed

6 files changed

+54
-58
lines changed

.github/workflows/downstream.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ jobs:
9090
9191
qtconsole:
9292
runs-on: ubuntu-latest
93-
if: false
9493
timeout-minutes: 20
9594
steps:
9695
- name: Checkout

ipykernel/kernelbase.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def _default_ident(self):
205205
# see https://github.com/jupyterlab/jupyterlab/issues/17785
206206
_parent_ident: Mapping[str, bytes]
207207

208+
# Asyncio lock for main shell thread.
209+
_main_asyncio_lock: asyncio.Lock
210+
208211
@property
209212
def _parent_header(self):
210213
warnings.warn(
@@ -327,6 +330,8 @@ def __init__(self, **kwargs):
327330
}
328331
)
329332

333+
self._main_asyncio_lock = asyncio.Lock()
334+
330335
async def dispatch_control(self, msg):
331336
"""Dispatch a control request, ensuring only one message is processed at a time."""
332337
# Ensure only one control message is processed at a time
@@ -539,7 +544,7 @@ async def do_one_iteration(self):
539544
This is now a coroutine
540545
"""
541546
# flush messages off of shell stream into the message queue
542-
if self.shell_stream:
547+
if self.shell_stream and not self._supports_kernel_subshells:
543548
self.shell_stream.flush()
544549
# process at most one shell message per iteration
545550
await self.process_one(wait=False)
@@ -649,56 +654,51 @@ async def shell_channel_thread_main(self, msg):
649654
"""Handler for shell messages received on shell_channel_thread"""
650655
assert threading.current_thread() == self.shell_channel_thread
651656

652-
if self.session is None:
653-
return
654-
655-
# deserialize only the header to get subshell_id
656-
# Keep original message to send to subshell_id unmodified.
657-
_, msg2 = self.session.feed_identities(msg, copy=False)
658-
try:
659-
msg3 = self.session.deserialize(msg2, content=False, copy=False)
660-
subshell_id = msg3["header"].get("subshell_id")
661-
662-
# Find inproc pair socket to use to send message to correct subshell.
663-
subshell_manager = self.shell_channel_thread.manager
664-
socket = subshell_manager.get_shell_channel_to_subshell_socket(subshell_id)
665-
assert socket is not None
666-
socket.send_multipart(msg, copy=False)
667-
except Exception:
668-
self.log.error("Invalid message", exc_info=True) # noqa: G201
657+
async with self.shell_channel_thread.asyncio_lock:
658+
if self.session is None:
659+
return
669660

670-
if self.shell_stream:
671-
self.shell_stream.flush()
661+
# deserialize only the header to get subshell_id
662+
# Keep original message to send to subshell_id unmodified.
663+
_, msg2 = self.session.feed_identities(msg, copy=False)
664+
try:
665+
msg3 = self.session.deserialize(msg2, content=False, copy=False)
666+
subshell_id = msg3["header"].get("subshell_id")
667+
668+
# Find inproc pair socket to use to send message to correct subshell.
669+
subshell_manager = self.shell_channel_thread.manager
670+
socket = subshell_manager.get_shell_channel_to_subshell_socket(subshell_id)
671+
assert socket is not None
672+
socket.send_multipart(msg, copy=False)
673+
except Exception:
674+
self.log.error("Invalid message", exc_info=True) # noqa: G201
672675

673676
async def shell_main(self, subshell_id: str | None, msg):
674677
"""Handler of shell messages for a single subshell"""
675678
if self._supports_kernel_subshells:
676679
if subshell_id is None:
677680
assert threading.current_thread() == threading.main_thread()
681+
asyncio_lock = self._main_asyncio_lock
678682
else:
679683
assert threading.current_thread() not in (
680684
self.shell_channel_thread,
681685
threading.main_thread(),
682686
)
683-
socket_pair = self.shell_channel_thread.manager.get_shell_channel_to_subshell_pair(
684-
subshell_id
685-
)
687+
asyncio_lock = self.shell_channel_thread.manager.get_subshell_asyncio_lock(
688+
subshell_id
689+
)
686690
else:
687691
assert subshell_id is None
688692
assert threading.current_thread() == threading.main_thread()
689-
socket_pair = None
690-
691-
try:
692-
# Whilst executing a shell message, do not accept any other shell messages on the
693-
# same subshell, so that cells are run sequentially. Without this we can run multiple
694-
# async cells at the same time which would be a nice feature to have but is an API
695-
# change.
696-
if socket_pair:
697-
socket_pair.pause_on_recv()
693+
asyncio_lock = self._main_asyncio_lock
694+
695+
# Whilst executing a shell message, do not accept any other shell messages on the
696+
# same subshell, so that cells are run sequentially. Without this we can run multiple
697+
# async cells at the same time which would be a nice feature to have but is an API
698+
# change.
699+
assert asyncio_lock is not None
700+
async with asyncio_lock:
698701
await self.dispatch_shell(msg, subshell_id=subshell_id)
699-
finally:
700-
if socket_pair:
701-
socket_pair.resume_on_recv()
702702

703703
def record_ports(self, ports):
704704
"""Record the ports that this kernel is using.
@@ -739,7 +739,7 @@ def _publish_status(self, status, channel, parent=None):
739739
def _publish_status_and_flush(self, status, channel, stream, parent=None):
740740
"""send status on IOPub and flush specified stream to ensure reply is sent before handling the next reply"""
741741
self._publish_status(status, channel, parent)
742-
if stream and hasattr(stream, "flush"):
742+
if stream and hasattr(stream, "flush") and not self._supports_kernel_subshells:
743743
stream.flush(zmq.POLLOUT)
744744

745745
def _publish_debug_event(self, event):
@@ -1382,7 +1382,7 @@ def _abort_queues(self, subshell_id: str | None = None):
13821382

13831383
# flush streams, so all currently waiting messages
13841384
# are added to the queue
1385-
if self.shell_stream:
1385+
if self.shell_stream and not self._supports_kernel_subshells:
13861386
self.shell_stream.flush()
13871387

13881388
# Callback to signal that we are done aborting

ipykernel/shellchannel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from typing import Any
67

78
import zmq
@@ -28,6 +29,8 @@ def __init__(
2829
self._context = context
2930
self._shell_socket = shell_socket
3031

32+
self.asyncio_lock = asyncio.Lock()
33+
3134
@property
3235
def manager(self) -> SubshellManager:
3336
# Lazy initialisation.

ipykernel/socket_pair.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ class SocketPair:
2121
from_socket: zmq.Socket[Any]
2222
to_socket: zmq.Socket[Any]
2323
to_stream: ZMQStream | None = None
24-
on_recv_callback: Any
25-
on_recv_copy: bool
2624

2725
def __init__(self, context: zmq.Context[Any], name: str):
2826
"""Initialize the inproc socker pair."""
@@ -43,21 +41,9 @@ def close(self):
4341
def on_recv(self, io_loop: IOLoop, on_recv_callback, copy: bool = False):
4442
"""Set the callback used when a message is received on the to stream."""
4543
# io_loop is that of the 'to' thread.
46-
self.on_recv_callback = on_recv_callback
47-
self.on_recv_copy = copy
4844
if self.to_stream is None:
4945
self.to_stream = ZMQStream(self.to_socket, io_loop)
50-
self.resume_on_recv()
51-
52-
def pause_on_recv(self):
53-
"""Pause receiving on the to stream."""
54-
if self.to_stream is not None:
55-
self.to_stream.stop_on_recv()
56-
57-
def resume_on_recv(self):
58-
"""Resume receiving on the to stream."""
59-
if self.to_stream is not None and not self.to_stream.closed():
60-
self.to_stream.on_recv(self.on_recv_callback, copy=self.on_recv_copy)
46+
self.to_stream.on_recv(on_recv_callback, copy=copy)
6147

6248
def _address(self, name) -> str:
6349
"""Return the address used for this inproc socket pair."""

ipykernel/subshell.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""A thread for a subshell."""
22

3+
import asyncio
34
from typing import Any
45

56
import zmq
@@ -29,6 +30,8 @@ def __init__(
2930
# When aborting flag is set, execute_request messages to this subshell will be aborted.
3031
self.aborting = False
3132

33+
self.asyncio_lock = asyncio.Lock()
34+
3235
def run(self) -> None:
3336
"""Run the thread."""
3437
try:

ipykernel/subshell_manager.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import json
67
import typing as t
78
import uuid
@@ -46,8 +47,7 @@ def __init__(
4647
self._shell_channel_io_loop = shell_channel_io_loop
4748
self._shell_socket = shell_socket
4849
self._cache: dict[str, SubshellThread] = {}
49-
self._lock_cache = Lock()
50-
self._lock_shell_socket = Lock()
50+
self._lock_cache = Lock() # Sync lock across threads when accessing cache.
5151

5252
# Inproc socket pair for communication from control thread to shell channel thread,
5353
# such as for create_subshell_request messages. Reply messages are returned straight away.
@@ -107,7 +107,13 @@ def get_shell_channel_to_subshell_socket(self, subshell_id: str | None) -> zmq.S
107107

108108
def get_subshell_aborting(self, subshell_id: str) -> bool:
109109
"""Get the boolean aborting flag of the specified subshell."""
110-
return self._cache[subshell_id].aborting
110+
with self._lock_cache:
111+
return self._cache[subshell_id].aborting
112+
113+
def get_subshell_asyncio_lock(self, subshell_id: str) -> asyncio.Lock:
114+
"""Return the asyncio lock belonging to the specified subshell."""
115+
with self._lock_cache:
116+
return self._cache[subshell_id].asyncio_lock
111117

112118
def list_subshell(self) -> list[str]:
113119
"""Return list of current subshell ids.
@@ -216,8 +222,7 @@ def _process_control_request(
216222

217223
def _send_on_shell_channel(self, msg) -> None:
218224
assert current_thread().name == SHELL_CHANNEL_THREAD_NAME
219-
with self._lock_shell_socket:
220-
self._shell_socket.send_multipart(msg)
225+
self._shell_socket.send_multipart(msg)
221226

222227
def _stop_subshell(self, subshell_thread: SubshellThread) -> None:
223228
"""Stop a subshell thread and close all of its resources."""

0 commit comments

Comments
 (0)