Skip to content

Commit fd9591d

Browse files
committed
store and use a reference to the kernel launching thread instead of assuming it's always the main thread
1 parent dccf9c3 commit fd9591d

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

ipykernel/ipkernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ async def run_cell(*args, **kwargs):
452452

453453
cm = (
454454
self._cancel_on_sigint
455-
if threading.current_thread() == threading.main_thread()
455+
if threading.current_thread() == self.shell_channel_thread.parent_thread
456456
else self._dummy_context_manager
457457
)
458458
with cm(coro_future):

ipykernel/kernelbase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -677,19 +677,19 @@ async def shell_main(self, subshell_id: str | None, msg):
677677
"""Handler of shell messages for a single subshell"""
678678
if self._supports_kernel_subshells:
679679
if subshell_id is None:
680-
assert threading.current_thread() == threading.main_thread()
680+
assert threading.current_thread() == self.shell_channel_thread.parent_thread
681681
asyncio_lock = self._main_asyncio_lock
682682
else:
683683
assert threading.current_thread() not in (
684684
self.shell_channel_thread,
685-
threading.main_thread(),
685+
self.shell_channel_thread.parent_thread,
686686
)
687687
asyncio_lock = self.shell_channel_thread.manager.get_subshell_asyncio_lock(
688688
subshell_id
689689
)
690690
else:
691691
assert subshell_id is None
692-
assert threading.current_thread() == threading.main_thread()
692+
assert threading.current_thread() == self.shell_channel_thread.parent_thread
693693
asyncio_lock = self._main_asyncio_lock
694694

695695
# Whilst executing a shell message, do not accept any other shell messages on the

ipykernel/shellchannel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
from typing import Any
7+
from threading import current_thread
78

89
import zmq
910

@@ -28,13 +29,16 @@ def __init__(
2829
self._manager: SubshellManager | None = None
2930
self._zmq_context = context # Avoid use of self._context
3031
self._shell_socket = shell_socket
32+
# Record the parent thread - the thread that started the app (usually the main thread)
33+
self.parent_thread = current_thread()
3134

3235
self.asyncio_lock = asyncio.Lock()
3336

3437
@property
3538
def manager(self) -> SubshellManager:
3639
# Lazy initialisation.
3740
if self._manager is None:
41+
assert current_thread() == self.parent_thread
3842
self._manager = SubshellManager(
3943
self._zmq_context,
4044
self.io_loop,

ipykernel/subshell_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import typing as t
88
import uuid
99
from functools import partial
10-
from threading import Lock, current_thread, main_thread
10+
from threading import Lock, current_thread
1111

1212
import zmq
1313
from tornado.ioloop import IOLoop
@@ -41,7 +41,7 @@ def __init__(
4141
shell_socket: zmq.Socket[t.Any],
4242
):
4343
"""Initialize the subshell manager."""
44-
assert current_thread() == main_thread()
44+
self._parent_thread = current_thread()
4545

4646
self._context: zmq.Context[t.Any] = context
4747
self._shell_channel_io_loop = shell_channel_io_loop
@@ -127,7 +127,7 @@ def set_on_recv_callback(self, on_recv_callback):
127127
"""Set the callback used by the main shell and all subshells to receive
128128
messages sent from the shell channel thread.
129129
"""
130-
assert current_thread() == main_thread()
130+
assert current_thread() == self._parent_thread
131131
self._on_recv_callback = on_recv_callback
132132
self._shell_channel_to_main.on_recv(IOLoop.current(), partial(self._on_recv_callback, None))
133133

@@ -144,7 +144,7 @@ def subshell_id_from_thread_id(self, thread_id: int) -> str | None:
144144
Only used by %subshell magic so does not have to be fast/cached.
145145
"""
146146
with self._lock_cache:
147-
if thread_id == main_thread().ident:
147+
if thread_id == self._parent_thread.ident:
148148
return None
149149
for id, subshell in self._cache.items():
150150
if subshell.ident == thread_id:

0 commit comments

Comments
 (0)