Skip to content

Commit efd53d7

Browse files
authored
Cache separate headers on subshell threads (#1414)
1 parent 7daba5b commit efd53d7

File tree

7 files changed

+134
-23
lines changed

7 files changed

+134
-23
lines changed

.github/workflows/downstream.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ jobs:
4141
test_command: pytest -vv -raXxs -W default --durations 10 --color=yes
4242

4343
jupyter_client:
44+
if: false
4445
runs-on: ubuntu-latest
4546
steps:
4647
- name: Checkout
@@ -55,6 +56,7 @@ jobs:
5556
package_name: jupyter_client
5657

5758
ipyparallel:
59+
if: false
5860
runs-on: ubuntu-latest
5961
timeout-minutes: 20
6062
steps:

ipykernel/displayhook.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import builtins
88
import sys
99
import typing as t
10+
from contextvars import ContextVar
1011

1112
from IPython.core.displayhook import DisplayHook
1213
from jupyter_client.session import Session, extract_header
13-
from traitlets import Any, Dict, Instance
14+
from traitlets import Any, Instance
1415

1516
from ipykernel.jsonutil import encode_images, json_clean
1617

@@ -25,7 +26,9 @@ def __init__(self, session, pub_socket):
2526
"""Initialize the hook."""
2627
self.session = session
2728
self.pub_socket = pub_socket
28-
self.parent_header = {}
29+
30+
self._parent_header: ContextVar[dict[str, Any]] = ContextVar("parent_header")
31+
self._parent_header.set({})
2932

3033
def get_execution_count(self):
3134
"""This method is replaced in kernelapp"""
@@ -45,12 +48,20 @@ def __call__(self, obj):
4548
"metadata": {},
4649
}
4750
self.session.send(
48-
self.pub_socket, "execute_result", contents, parent=self.parent_header, ident=self.topic
51+
self.pub_socket,
52+
"execute_result",
53+
contents,
54+
parent=self.parent_header,
55+
ident=self.topic,
4956
)
5057

58+
@property
59+
def parent_header(self):
60+
return self._parent_header.get()
61+
5162
def set_parent(self, parent):
5263
"""Set the parent header."""
53-
self.parent_header = extract_header(parent)
64+
self._parent_header.set(extract_header(parent))
5465

5566

5667
class ZMQShellDisplayHook(DisplayHook):
@@ -62,12 +73,21 @@ class ZMQShellDisplayHook(DisplayHook):
6273

6374
session = Instance(Session, allow_none=True)
6475
pub_socket = Any(allow_none=True)
65-
parent_header = Dict({})
76+
_parent_header: ContextVar[dict[str, Any]]
6677
msg: dict[str, t.Any] | None
6778

79+
def __init__(self, *args, **kwargs):
80+
super().__init__(*args, **kwargs)
81+
self._parent_header = ContextVar("parent_header")
82+
self._parent_header.set({})
83+
84+
@property
85+
def parent_header(self):
86+
return self._parent_header.get()
87+
6888
def set_parent(self, parent):
6989
"""Set the parent for outbound messages."""
70-
self.parent_header = extract_header(parent)
90+
self._parent_header.set(extract_header(parent))
7191

7292
def start_displayhook(self):
7393
"""Start the display hook."""

ipykernel/kernelbase.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import typing as t
1717
import uuid
1818
import warnings
19+
from contextvars import ContextVar
1920
from datetime import datetime
2021
from functools import partial
2122
from signal import SIGINT, SIGTERM, Signals, default_int_handler, signal
@@ -194,8 +195,10 @@ def _default_ident(self):
194195

195196
# track associations with current request
196197
_allow_stdin = Bool(False)
197-
_parents: Dict[str, t.Any] = Dict({"shell": {}, "control": {}})
198-
_parent_ident = Dict({"shell": b"", "control": b""})
198+
_control_parent: Dict[str, t.Any] = Dict({})
199+
_control_parent_ident: bytes = b""
200+
_shell_parent: ContextVar[dict[str, Any]]
201+
_shell_parent_ident: ContextVar[bytes]
199202

200203
@property
201204
def _parent_header(self):
@@ -302,6 +305,14 @@ def __init__(self, **kwargs):
302305
self.do_execute, ["cell_meta", "cell_id"]
303306
)
304307

308+
self._control_parent = {}
309+
self._control_parent_ident = b""
310+
311+
self._shell_parent = ContextVar("shell_parent")
312+
self._shell_parent.set({})
313+
self._shell_parent_ident = ContextVar("shell_parent_ident")
314+
self._shell_parent_ident.set(b"")
315+
305316
async def dispatch_control(self, msg):
306317
"""Dispatch a control request, ensuring only one message is processed at a time."""
307318
# Ensure only one control message is processed at a time
@@ -737,8 +748,12 @@ def set_parent(self, ident, parent, channel="shell"):
737748
The parent identity is used to route input_request messages
738749
on the stdin channel.
739750
"""
740-
self._parent_ident[channel] = ident
741-
self._parents[channel] = parent
751+
if channel == "control":
752+
self._control_parent_ident = ident
753+
self._control_parent = parent
754+
else:
755+
self._shell_parent_ident.set(ident)
756+
self._shell_parent.set(parent)
742757

743758
def get_parent(self, channel=None):
744759
"""Get the parent request associated with a channel.
@@ -763,7 +778,9 @@ def get_parent(self, channel=None):
763778
else:
764779
channel = "shell"
765780

766-
return self._parents.get(channel, {})
781+
if channel == "control":
782+
return self._control_parent
783+
return self._shell_parent.get()
767784

768785
def send_response(
769786
self,
@@ -1424,7 +1441,7 @@ def getpass(self, prompt="", stream=None):
14241441
)
14251442
return self._input_request(
14261443
prompt,
1427-
self._parent_ident["shell"],
1444+
self._shell_parent_ident.get(),
14281445
self.get_parent("shell"),
14291446
password=True,
14301447
)
@@ -1441,7 +1458,7 @@ def raw_input(self, prompt=""):
14411458
raise StdinNotImplementedError(msg)
14421459
return self._input_request(
14431460
str(prompt),
1444-
self._parent_ident["shell"],
1461+
self._shell_parent_ident.get(),
14451462
self.get_parent("shell"),
14461463
password=False,
14471464
)

ipykernel/zmqshell.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# Copyright (c) IPython Development Team.
1515
# Distributed under the terms of the Modified BSD License.
1616

17+
import contextvars
1718
import os
1819
import sys
1920
import threading
@@ -34,7 +35,7 @@
3435
from IPython.utils.process import arg_split, system # type:ignore[attr-defined]
3536
from jupyter_client.session import Session, extract_header
3637
from jupyter_core.paths import jupyter_runtime_dir
37-
from traitlets import Any, CBool, CBytes, Dict, Instance, Type, default, observe
38+
from traitlets import Any, CBool, CBytes, Instance, Type, default, observe
3839

3940
from ipykernel import connect_qtconsole, get_connection_file, get_connection_info
4041
from ipykernel.displayhook import ZMQShellDisplayHook
@@ -50,17 +51,26 @@ class ZMQDisplayPublisher(DisplayPublisher):
5051

5152
session = Instance(Session, allow_none=True)
5253
pub_socket = Any(allow_none=True)
53-
parent_header = Dict({})
54+
_parent_header: contextvars.ContextVar[dict[str, Any]]
5455
topic = CBytes(b"display_data")
5556

5657
# thread_local:
5758
# An attribute used to ensure the correct output message
5859
# is processed. See ipykernel Issue 113 for a discussion.
5960
_thread_local = Any()
6061

62+
def __init__(self, *args, **kwargs):
63+
super().__init__(*args, **kwargs)
64+
self._parent_header = contextvars.ContextVar("parent_header")
65+
self._parent_header.set({})
66+
67+
@property
68+
def parent_header(self):
69+
return self._parent_header.get()
70+
6171
def set_parent(self, parent):
6272
"""Set the parent for outbound messages."""
63-
self.parent_header = extract_header(parent)
73+
self._parent_header.set(extract_header(parent))
6474

6575
def _flush_streams(self):
6676
"""flush IO Streams prior to display"""
@@ -485,11 +495,14 @@ def __init__(self, *args, **kwargs):
485495
if "IPKernelApp" not in self.config:
486496
self.config.IPKernelApp.tqdm = "dummy value for https://github.com/tqdm/tqdm/pull/1628"
487497

498+
self._parent_header = contextvars.ContextVar("parent_header")
499+
self._parent_header.set({})
500+
488501
displayhook_class = Type(ZMQShellDisplayHook)
489502
display_pub_class = Type(ZMQDisplayPublisher)
490503
data_pub_class = Any()
491504
kernel = Any()
492-
parent_header = Any()
505+
_parent_header: contextvars.ContextVar[dict[str, Any]]
493506

494507
@default("banner1")
495508
def _default_banner1(self):
@@ -658,9 +671,13 @@ def set_next_input(self, text, replace=False):
658671
)
659672
self.payload_manager.write_payload(payload) # type:ignore[union-attr]
660673

674+
@property
675+
def parent_header(self):
676+
return self._parent_header.get()
677+
661678
def set_parent(self, parent):
662679
"""Set the parent header for associating output with its triggering input"""
663-
self.parent_header = parent
680+
self._parent_header.set(parent)
664681
self.displayhook.set_parent(parent) # type:ignore[attr-defined]
665682
self.display_pub.set_parent(parent) # type:ignore[attr-defined]
666683
if hasattr(self, "_data_pub"):

tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import no_type_check
55
from unittest.mock import MagicMock
66

7-
import pytest
7+
import pytest_asyncio
88
import zmq
99
from jupyter_client.session import Session
1010
from tornado.ioloop import IOLoop
@@ -143,15 +143,15 @@ def __init__(self, *args, **kwargs):
143143
super().__init__(*args, **kwargs)
144144

145145

146-
@pytest.fixture()
146+
@pytest_asyncio.fixture()
147147
def kernel():
148148
kernel = MockKernel()
149149
kernel.io_loop = IOLoop.current()
150150
yield kernel
151151
kernel.destroy()
152152

153153

154-
@pytest.fixture()
154+
@pytest_asyncio.fixture()
155155
def ipkernel():
156156
kernel = MockIPyKernel()
157157
kernel.io_loop = IOLoop.current()

tests/test_subshells.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66

77
import platform
88
import time
9+
from collections import Counter
910

1011
import pytest
1112
from jupyter_client.blocking.client import BlockingKernelClient
1213

13-
from .utils import TIMEOUT, assemble_output, get_replies, get_reply, new_kernel
14+
from .utils import TIMEOUT, assemble_output, get_replies, get_reply, new_kernel, wait_for_idle
1415

1516
# Helpers
1617

@@ -258,3 +259,57 @@ def test_execute_stop_on_error(are_subshells):
258259
for subshell_id in subshell_ids:
259260
if subshell_id:
260261
delete_subshell_helper(kc, subshell_id)
262+
263+
264+
@pytest.mark.parametrize("are_subshells", [(False, True), (True, False), (True, True)])
265+
def test_idle_message_parent_headers(are_subshells):
266+
with new_kernel() as kc:
267+
# import time module on main shell.
268+
msg = kc.session.msg("execute_request", {"code": "import time"})
269+
kc.shell_channel.send(msg)
270+
271+
subshell_ids = [
272+
create_subshell_helper(kc)["subshell_id"] if is_subshell else None
273+
for is_subshell in are_subshells
274+
]
275+
276+
# Wait for all idle status messages to be received.
277+
for _ in range(1 + sum(are_subshells)):
278+
wait_for_idle(kc)
279+
280+
msg_ids = []
281+
for subshell_id in subshell_ids:
282+
msg = execute_request(kc, "time.sleep(0.5)", subshell_id)
283+
msg_ids.append(msg["msg_id"])
284+
285+
# Expect 4 status messages (2 busy, 2 idle) on iopub channel for the two execute_requests
286+
statuses = []
287+
timeout = TIMEOUT # Combined timeout to receive all the status messages
288+
t0 = time.time()
289+
while True:
290+
status = kc.get_iopub_msg(timeout=timeout)
291+
if status["msg_type"] != "status" or status["parent_header"]["msg_id"] not in msg_ids:
292+
continue
293+
statuses.append(status)
294+
if len(statuses) == 4:
295+
break
296+
t1 = time.time()
297+
timeout -= t1 - t0
298+
t0 = t1
299+
300+
execution_states = Counter(msg["content"]["execution_state"] for msg in statuses)
301+
assert execution_states["busy"] == 2
302+
assert execution_states["idle"] == 2
303+
304+
parent_msg_ids = Counter(msg["parent_header"]["msg_id"] for msg in statuses)
305+
assert parent_msg_ids[msg_ids[0]] == 2
306+
assert parent_msg_ids[msg_ids[1]] == 2
307+
308+
parent_subshell_ids = Counter(msg["parent_header"].get("subshell_id") for msg in statuses)
309+
assert parent_subshell_ids[subshell_ids[0]] == 2
310+
assert parent_subshell_ids[subshell_ids[1]] == 2
311+
312+
# Cleanup
313+
for subshell_id in subshell_ids:
314+
if subshell_id:
315+
delete_subshell_helper(kc, subshell_id)

tests/test_zmq_shell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_zmq_interactive_shell(kernel):
245245
shell.data_pub
246246
shell.kernel = kernel
247247
shell.set_next_input("hi")
248-
assert shell.get_parent() is None
248+
assert shell.get_parent() == {}
249249
if os.name == "posix":
250250
shell.system_piped("ls")
251251
else:

0 commit comments

Comments
 (0)