Skip to content

Commit 5410bd9

Browse files
committed
Cache separate headers on subshell threads
1 parent 7daba5b commit 5410bd9

File tree

4 files changed

+123
-18
lines changed

4 files changed

+123
-18
lines changed

ipykernel/displayhook.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import builtins
88
import sys
99
import typing as t
10+
from contextvars import ContextVar
1011

1112
from IPython.core.displayhook import DisplayHook
1213
from jupyter_client.session import Session, extract_header
13-
from traitlets import Any, Dict, Instance
14+
from traitlets import Any, Instance
1415

1516
from ipykernel.jsonutil import encode_images, json_clean
1617

@@ -25,7 +26,9 @@ def __init__(self, session, pub_socket):
2526
"""Initialize the hook."""
2627
self.session = session
2728
self.pub_socket = pub_socket
28-
self.parent_header = {}
29+
30+
self._parent_header: ContextVar[dict[str, Any]] = ContextVar("parent_header")
31+
self._parent_header.set({})
2932

3033
def get_execution_count(self):
3134
"""This method is replaced in kernelapp"""
@@ -45,12 +48,20 @@ def __call__(self, obj):
4548
"metadata": {},
4649
}
4750
self.session.send(
48-
self.pub_socket, "execute_result", contents, parent=self.parent_header, ident=self.topic
51+
self.pub_socket,
52+
"execute_result",
53+
contents,
54+
parent=self.parent_header,
55+
ident=self.topic,
4956
)
5057

58+
@property
59+
def parent_header(self):
60+
return self._parent_header.get()
61+
5162
def set_parent(self, parent):
5263
"""Set the parent header."""
53-
self.parent_header = extract_header(parent)
64+
self._parent_header.set(extract_header(parent))
5465

5566

5667
class ZMQShellDisplayHook(DisplayHook):
@@ -62,12 +73,21 @@ class ZMQShellDisplayHook(DisplayHook):
6273

6374
session = Instance(Session, allow_none=True)
6475
pub_socket = Any(allow_none=True)
65-
parent_header = Dict({})
76+
_parent_header: ContextVar[dict[str, Any]]
6677
msg: dict[str, t.Any] | None
6778

79+
def __init__(self, *args, **kwargs):
80+
super().__init__(*args, **kwargs)
81+
self._parent_header = ContextVar("parent_header")
82+
self._parent_header.set({})
83+
84+
@property
85+
def parent_header(self):
86+
return self._parent_header.get()
87+
6888
def set_parent(self, parent):
6989
"""Set the parent for outbound messages."""
70-
self.parent_header = extract_header(parent)
90+
self._parent_header.set(extract_header(parent))
7191

7292
def start_displayhook(self):
7393
"""Start the display hook."""

ipykernel/kernelbase.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import typing as t
1717
import uuid
1818
import warnings
19+
from contextvars import ContextVar
1920
from datetime import datetime
2021
from functools import partial
2122
from signal import SIGINT, SIGTERM, Signals, default_int_handler, signal
@@ -194,8 +195,10 @@ def _default_ident(self):
194195

195196
# track associations with current request
196197
_allow_stdin = Bool(False)
197-
_parents: Dict[str, t.Any] = Dict({"shell": {}, "control": {}})
198-
_parent_ident = Dict({"shell": b"", "control": b""})
198+
_control_parent: Dict[str, t.Any] = Dict({})
199+
_control_parent_ident: bytes = b""
200+
_shell_parent: ContextVar[dict[str, Any]]
201+
_shell_parent_ident: ContextVar[bytes]
199202

200203
@property
201204
def _parent_header(self):
@@ -302,6 +305,14 @@ def __init__(self, **kwargs):
302305
self.do_execute, ["cell_meta", "cell_id"]
303306
)
304307

308+
self._control_parent = {}
309+
self._control_parent_ident = b""
310+
311+
self._shell_parent = ContextVar("shell_parent")
312+
self._shell_parent.set({})
313+
self._shell_parent_ident = ContextVar("shell_parent_ident")
314+
self._shell_parent_ident.set(b"")
315+
305316
async def dispatch_control(self, msg):
306317
"""Dispatch a control request, ensuring only one message is processed at a time."""
307318
# Ensure only one control message is processed at a time
@@ -737,8 +748,12 @@ def set_parent(self, ident, parent, channel="shell"):
737748
The parent identity is used to route input_request messages
738749
on the stdin channel.
739750
"""
740-
self._parent_ident[channel] = ident
741-
self._parents[channel] = parent
751+
if channel == "control":
752+
self._control_parent_ident = ident
753+
self._control_parent = parent
754+
else:
755+
self._shell_parent_ident.set(ident)
756+
self._shell_parent.set(parent)
742757

743758
def get_parent(self, channel=None):
744759
"""Get the parent request associated with a channel.
@@ -763,7 +778,9 @@ def get_parent(self, channel=None):
763778
else:
764779
channel = "shell"
765780

766-
return self._parents.get(channel, {})
781+
if channel == "control":
782+
return self._control_parent
783+
return self._shell_parent.get()
767784

768785
def send_response(
769786
self,
@@ -1424,7 +1441,7 @@ def getpass(self, prompt="", stream=None):
14241441
)
14251442
return self._input_request(
14261443
prompt,
1427-
self._parent_ident["shell"],
1444+
self._shell_parent_ident.get(),
14281445
self.get_parent("shell"),
14291446
password=True,
14301447
)
@@ -1441,7 +1458,7 @@ def raw_input(self, prompt=""):
14411458
raise StdinNotImplementedError(msg)
14421459
return self._input_request(
14431460
str(prompt),
1444-
self._parent_ident["shell"],
1461+
self._shell_parent_ident.get(),
14451462
self.get_parent("shell"),
14461463
password=False,
14471464
)

ipykernel/zmqshell.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# Copyright (c) IPython Development Team.
1515
# Distributed under the terms of the Modified BSD License.
1616

17+
import contextvars
1718
import os
1819
import sys
1920
import threading
@@ -34,7 +35,7 @@
3435
from IPython.utils.process import arg_split, system # type:ignore[attr-defined]
3536
from jupyter_client.session import Session, extract_header
3637
from jupyter_core.paths import jupyter_runtime_dir
37-
from traitlets import Any, CBool, CBytes, Dict, Instance, Type, default, observe
38+
from traitlets import Any, CBool, CBytes, Instance, Type, default, observe
3839

3940
from ipykernel import connect_qtconsole, get_connection_file, get_connection_info
4041
from ipykernel.displayhook import ZMQShellDisplayHook
@@ -50,17 +51,26 @@ class ZMQDisplayPublisher(DisplayPublisher):
5051

5152
session = Instance(Session, allow_none=True)
5253
pub_socket = Any(allow_none=True)
53-
parent_header = Dict({})
54+
_parent_header: contextvars.ContextVar[dict[str, Any]]
5455
topic = CBytes(b"display_data")
5556

5657
# thread_local:
5758
# An attribute used to ensure the correct output message
5859
# is processed. See ipykernel Issue 113 for a discussion.
5960
_thread_local = Any()
6061

62+
def __init__(self, *args, **kwargs):
63+
super().__init__(*args, **kwargs)
64+
self._parent_header = contextvars.ContextVar("parent_header")
65+
self._parent_header.set({})
66+
67+
@property
68+
def parent_header(self):
69+
return self._parent_header.get()
70+
6171
def set_parent(self, parent):
6272
"""Set the parent for outbound messages."""
63-
self.parent_header = extract_header(parent)
73+
self._parent_header.set(extract_header(parent))
6474

6575
def _flush_streams(self):
6676
"""flush IO Streams prior to display"""
@@ -485,11 +495,14 @@ def __init__(self, *args, **kwargs):
485495
if "IPKernelApp" not in self.config:
486496
self.config.IPKernelApp.tqdm = "dummy value for https://github.com/tqdm/tqdm/pull/1628"
487497

498+
self._parent_header = contextvars.ContextVar("parent_header")
499+
self._parent_header.set({})
500+
488501
displayhook_class = Type(ZMQShellDisplayHook)
489502
display_pub_class = Type(ZMQDisplayPublisher)
490503
data_pub_class = Any()
491504
kernel = Any()
492-
parent_header = Any()
505+
_parent_header: contextvars.ContextVar[dict[str, Any]]
493506

494507
@default("banner1")
495508
def _default_banner1(self):
@@ -658,9 +671,13 @@ def set_next_input(self, text, replace=False):
658671
)
659672
self.payload_manager.write_payload(payload) # type:ignore[union-attr]
660673

674+
@property
675+
def parent_header(self):
676+
return self._parent_header.get()
677+
661678
def set_parent(self, parent):
662679
"""Set the parent header for associating output with its triggering input"""
663-
self.parent_header = parent
680+
self._parent_header.set(parent)
664681
self.displayhook.set_parent(parent) # type:ignore[attr-defined]
665682
self.display_pub.set_parent(parent) # type:ignore[attr-defined]
666683
if hasattr(self, "_data_pub"):

tests/test_subshells.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import platform
88
import time
9+
from collections import Counter
910

1011
import pytest
1112
from jupyter_client.blocking.client import BlockingKernelClient
@@ -258,3 +259,53 @@ def test_execute_stop_on_error(are_subshells):
258259
for subshell_id in subshell_ids:
259260
if subshell_id:
260261
delete_subshell_helper(kc, subshell_id)
262+
263+
264+
@pytest.mark.parametrize("are_subshells", [(False, True), (True, False), (True, True)])
265+
def test_idle_message_parent_headers(are_subshells):
266+
with new_kernel() as kc:
267+
# import time module on main shell.
268+
msg = kc.session.msg("execute_request", {"code": "import time"})
269+
kc.shell_channel.send(msg)
270+
271+
subshell_ids = [
272+
create_subshell_helper(kc)["subshell_id"] if is_subshell else None
273+
for is_subshell in are_subshells
274+
]
275+
276+
msg_ids = []
277+
for subshell_id in subshell_ids:
278+
msg = execute_request(kc, "time.sleep(0.5)", subshell_id)
279+
msg_ids.append(msg["msg_id"])
280+
281+
# Expect 4 status messages (2 busy, 2 idle) on iopub channel for the two execute_requests
282+
statuses = []
283+
timeout = TIMEOUT # Combined timeout to receive all the status messages
284+
t0 = time.time()
285+
while True:
286+
status = kc.get_iopub_msg(timeout=timeout)
287+
if status["msg_type"] != "status" or status["parent_header"]["msg_id"] not in msg_ids:
288+
continue
289+
statuses.append(status)
290+
if len(statuses) == 4:
291+
break
292+
t1 = time.time()
293+
timeout -= t1 - t0
294+
t0 = t1
295+
296+
execution_states = Counter(msg["content"]["execution_state"] for msg in statuses)
297+
assert execution_states["busy"] == 2
298+
assert execution_states["idle"] == 2
299+
300+
parent_msg_ids = Counter(msg["parent_header"]["msg_id"] for msg in statuses)
301+
assert parent_msg_ids[msg_ids[0]] == 2
302+
assert parent_msg_ids[msg_ids[1]] == 2
303+
304+
parent_subshell_ids = Counter(msg["parent_header"].get("subshell_id") for msg in statuses)
305+
assert parent_subshell_ids[subshell_ids[0]] == 2
306+
assert parent_subshell_ids[subshell_ids[1]] == 2
307+
308+
# Cleanup
309+
for subshell_id in subshell_ids:
310+
if subshell_id:
311+
delete_subshell_helper(kc, subshell_id)

0 commit comments

Comments
 (0)