Skip to content

Commit 2c80e6c

Browse files
maartenbreddelspre-commit-ci[bot]blink1073
authored
Fix: there can be only one comm_manager (#1049)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Silvester <[email protected]> Fixes #1043
1 parent 4dc3033 commit 2c80e6c

File tree

4 files changed

+101
-7
lines changed

4 files changed

+101
-7
lines changed

ipykernel/comm/comm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def _default_comm_id(self):
7171
def __init__(self, *args, **kwargs):
7272
# Comm takes positional arguments, LoggingConfigurable does not, so we explicitly forward arguments
7373
traitlets.config.LoggingConfigurable.__init__(self, **kwargs)
74+
for name in self.trait_names():
75+
if name in kwargs:
76+
kwargs.pop(name)
7477
BaseComm.__init__(self, *args, **kwargs)
7578

7679

ipykernel/ipkernel.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import getpass
66
import signal
77
import sys
8+
import threading
89
import typing as t
910
from contextlib import contextmanager
1011
from functools import partial
@@ -46,9 +47,19 @@ def _create_comm(*args, **kwargs):
4647
return BaseComm(*args, **kwargs)
4748

4849

50+
# there can only be one comm manager in a ipykernel process
51+
_comm_lock = threading.Lock()
52+
_comm_manager: t.Optional[CommManager] = None
53+
54+
4955
def _get_comm_manager(*args, **kwargs):
5056
"""Create a new CommManager."""
51-
return CommManager(*args, **kwargs)
57+
global _comm_manager
58+
if _comm_manager is None:
59+
with _comm_lock:
60+
if _comm_manager is None:
61+
_comm_manager = CommManager(*args, **kwargs)
62+
return _comm_manager
5263

5364

5465
comm.create_comm = _create_comm

ipykernel/tests/test_comm.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,87 @@
1-
from ipykernel.comm import Comm
1+
from ipykernel.comm import Comm, CommManager
2+
from ipykernel.ipkernel import IPythonKernel
23

34

4-
async def test_comm(kernel):
5-
c = Comm()
6-
c.kernel = kernel # type:ignore
5+
def test_comm(kernel):
6+
manager = CommManager(kernel=kernel)
7+
kernel.comm_manager = manager
8+
9+
c = Comm(kernel=kernel)
10+
msgs = []
11+
12+
def on_close(msg):
13+
msgs.append(msg)
14+
15+
def on_message(msg):
16+
msgs.append(msg)
17+
718
c.publish_msg("foo")
19+
c.open({})
20+
c.on_msg(on_message)
21+
c.on_close(on_close)
22+
c.handle_msg({})
23+
c.handle_close({})
24+
c.close()
25+
assert len(msgs) == 2
26+
27+
28+
def test_comm_manager(kernel):
29+
manager = CommManager(kernel=kernel)
30+
msgs = []
31+
32+
def foo(comm, msg):
33+
msgs.append(msg)
34+
comm.close()
35+
36+
def fizz(comm, msg):
37+
raise RuntimeError('hi')
38+
39+
def on_close(msg):
40+
msgs.append(msg)
41+
42+
def on_msg(msg):
43+
msgs.append(msg)
44+
45+
manager.register_target("foo", foo)
46+
manager.register_target("fizz", fizz)
47+
48+
kernel.comm_manager = manager
49+
comm = Comm()
50+
comm.on_msg(on_msg)
51+
comm.on_close(on_close)
52+
manager.register_comm(comm)
53+
54+
assert manager.get_comm(comm.comm_id) == comm
55+
assert manager.get_comm('foo') is None
56+
57+
msg = dict(content=dict(comm_id=comm.comm_id, target_name='foo'))
58+
manager.comm_open(None, None, msg)
59+
assert len(msgs) == 1
60+
msg['content']['target_name'] = 'bar'
61+
manager.comm_open(None, None, msg)
62+
assert len(msgs) == 1
63+
msg = dict(content=dict(comm_id=comm.comm_id, target_name='fizz'))
64+
manager.comm_open(None, None, msg)
65+
assert len(msgs) == 1
66+
67+
manager.register_comm(comm)
68+
assert manager.get_comm(comm.comm_id) == comm
69+
msg = dict(content=dict(comm_id=comm.comm_id))
70+
manager.comm_msg(None, None, msg)
71+
assert len(msgs) == 2
72+
msg['content']['comm_id'] = 'foo'
73+
manager.comm_msg(None, None, msg)
74+
assert len(msgs) == 2
75+
76+
manager.register_comm(comm)
77+
assert manager.get_comm(comm.comm_id) == comm
78+
msg = dict(content=dict(comm_id=comm.comm_id))
79+
manager.comm_close(None, None, msg)
80+
assert len(msgs) == 3
81+
82+
assert comm._closed
83+
84+
85+
def test_comm_in_manager(ipkernel: IPythonKernel) -> None:
86+
comm = Comm()
87+
assert comm.comm_id in ipkernel.comm_manager.comms

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ requires-python = ">=3.8"
2727
dependencies = [
2828
"debugpy>=1.0",
2929
"ipython>=7.23.1",
30-
"comm>=0.1",
31-
"traitlets>=5.1.0",
30+
"comm>=0.1.1",
31+
"traitlets>=5.4.0",
3232
"jupyter_client>=6.1.12",
3333
"tornado>=6.1",
3434
"matplotlib-inline>=0.1",

0 commit comments

Comments
 (0)