Skip to content

Commit bf10447

Browse files
authored
fix mixture of sync/async sockets in IOPubThread (#1275)
1 parent 8cc1ee3 commit bf10447

File tree

5 files changed

+64
-46
lines changed

5 files changed

+64
-46
lines changed

ipykernel/inprocess/ipkernel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import sys
88
from contextlib import contextmanager
9+
from typing import cast
910

1011
from anyio import TASK_STATUS_IGNORED
1112
from anyio.abc import TaskStatus
@@ -146,7 +147,8 @@ def callback(msg):
146147
assert frontend is not None
147148
frontend.iopub_channel.call_handlers(msg)
148149

149-
self.iopub_thread.socket.on_recv = callback
150+
iopub_socket = cast(DummySocket, self.iopub_thread.socket)
151+
iopub_socket.on_recv = callback
150152

151153
# ------ Trait initializers -----------------------------------------------
152154

ipykernel/inprocess/socket.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,6 @@ async def poll(self, timeout=0):
6363
assert timeout == 0
6464
statistics = self.in_receive_stream.statistics()
6565
return statistics.current_buffer_used != 0
66+
67+
def close(self):
68+
pass

ipykernel/iostream.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# Copyright (c) IPython Development Team.
44
# Distributed under the terms of the Modified BSD License.
55

6+
from __future__ import annotations
7+
68
import atexit
79
import contextvars
810
import io
@@ -15,7 +17,7 @@
1517
from collections import defaultdict, deque
1618
from io import StringIO, TextIOBase
1719
from threading import Event, Thread, local
18-
from typing import Any, Callable, Optional
20+
from typing import Any, Callable
1921

2022
import zmq
2123
from anyio import create_task_group, run, sleep, to_thread
@@ -25,8 +27,8 @@
2527
# Globals
2628
# -----------------------------------------------------------------------------
2729

28-
MASTER = 0
29-
CHILD = 1
30+
_PARENT = 0
31+
_CHILD = 1
3032

3133
PIPE_BUFFER_SIZE = 1000
3234

@@ -87,9 +89,16 @@ def __init__(self, socket, pipe=False):
8789
Whether this process should listen for IOPub messages
8890
piped from subprocesses.
8991
"""
90-
self.socket = socket
92+
# ensure all of our sockets as sync zmq.Sockets
93+
# don't create async wrappers until we are within the appropriate coroutines
94+
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
95+
if self.socket.context is None:
96+
# bug in pyzmq, shadow socket doesn't always inherit context attribute
97+
self.socket.context = socket.context # type:ignore[unreachable]
98+
self._context = socket.context
99+
91100
self.background_socket = BackgroundSocket(self)
92-
self._master_pid = os.getpid()
101+
self._main_pid = os.getpid()
93102
self._pipe_flag = pipe
94103
if pipe:
95104
self._setup_pipe_in()
@@ -106,8 +115,7 @@ def __init__(self, socket, pipe=False):
106115

107116
def _setup_event_pipe(self):
108117
"""Create the PULL socket listening for events that should fire in this thread."""
109-
ctx = self.socket.context
110-
self._pipe_in0 = ctx.socket(zmq.PULL)
118+
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
111119
self._pipe_in0.linger = 0
112120

113121
_uuid = b2a_hex(os.urandom(16)).decode("ascii")
@@ -141,8 +149,8 @@ def _event_pipe(self):
141149
event_pipe = self._local.event_pipe
142150
except AttributeError:
143151
# new thread, new event pipe
144-
ctx = zmq.Context(self.socket.context)
145-
event_pipe = ctx.socket(zmq.PUSH)
152+
# create sync base socket
153+
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
146154
event_pipe.linger = 0
147155
event_pipe.connect(self._event_interface)
148156
self._local.event_pipe = event_pipe
@@ -161,9 +169,11 @@ async def _handle_event(self):
161169
Whenever *an* event arrives on the event stream,
162170
*all* waiting events are processed in order.
163171
"""
172+
# create async wrapper within coroutine
173+
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
164174
try:
165175
while True:
166-
await self._pipe_in0.recv()
176+
await pipe_in.recv()
167177
# freeze event count so new writes don't extend the queue
168178
# while we are processing
169179
n_events = len(self._events)
@@ -177,12 +187,12 @@ async def _handle_event(self):
177187

178188
def _setup_pipe_in(self):
179189
"""setup listening pipe for IOPub from forked subprocesses"""
180-
ctx = self.socket.context
190+
ctx = self._context
181191

182192
# use UUID to authenticate pipe messages
183193
self._pipe_uuid = os.urandom(16)
184194

185-
self._pipe_in1 = ctx.socket(zmq.PULL)
195+
self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
186196
self._pipe_in1.linger = 0
187197

188198
try:
@@ -199,6 +209,8 @@ def _setup_pipe_in(self):
199209

200210
async def _handle_pipe_msgs(self):
201211
"""handle pipe messages from a subprocess"""
212+
# create async wrapper within coroutine
213+
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
202214
try:
203215
while True:
204216
await self._handle_pipe_msg()
@@ -209,8 +221,8 @@ async def _handle_pipe_msgs(self):
209221

210222
async def _handle_pipe_msg(self, msg=None):
211223
"""handle a pipe message from a subprocess"""
212-
msg = msg or await self._pipe_in1.recv_multipart()
213-
if not self._pipe_flag or not self._is_master_process():
224+
msg = msg or await self._async_pipe_in1.recv_multipart()
225+
if not self._pipe_flag or not self._is_main_process():
214226
return
215227
if msg[0] != self._pipe_uuid:
216228
print("Bad pipe message: %s", msg, file=sys.__stderr__)
@@ -225,14 +237,14 @@ def _setup_pipe_out(self):
225237
pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
226238
return ctx, pipe_out
227239

228-
def _is_master_process(self):
229-
return os.getpid() == self._master_pid
240+
def _is_main_process(self):
241+
return os.getpid() == self._main_pid
230242

231243
def _check_mp_mode(self):
232244
"""check for forks, and switch to zmq pipeline if necessary"""
233-
if not self._pipe_flag or self._is_master_process():
234-
return MASTER
235-
return CHILD
245+
if not self._pipe_flag or self._is_main_process():
246+
return _PARENT
247+
return _CHILD
236248

237249
def start(self):
238250
"""Start the IOPub thread"""
@@ -265,7 +277,8 @@ def close(self):
265277
self._pipe_in0.close()
266278
if self._pipe_flag:
267279
self._pipe_in1.close()
268-
self.socket.close()
280+
if self.socket is not None:
281+
self.socket.close()
269282
self.socket = None
270283

271284
@property
@@ -301,12 +314,12 @@ def _really_send(self, msg, *args, **kwargs):
301314
return
302315

303316
mp_mode = self._check_mp_mode()
304-
305-
if mp_mode != CHILD:
306-
# we are master, do a regular send
317+
if mp_mode != _CHILD:
318+
# we are the main parent process, do a regular send
319+
assert self.socket is not None
307320
self.socket.send_multipart(msg, *args, **kwargs)
308321
else:
309-
# we are a child, pipe to master
322+
# we are a child, pipe to parent process
310323
# new context/socket for every pipe-out
311324
# since forks don't teardown politely, use ctx.term to ensure send has completed
312325
ctx, pipe_out = self._setup_pipe_out()
@@ -379,7 +392,7 @@ class OutStream(TextIOBase):
379392
flush_interval = 0.2
380393
topic = None
381394
encoding = "UTF-8"
382-
_exc: Optional[Any] = None
395+
_exc: Any = None
383396

384397
def fileno(self):
385398
"""
@@ -477,7 +490,7 @@ def __init__(
477490
self._thread_to_parent = {}
478491
self._thread_to_parent_header = {}
479492
self._parent_header_global = {}
480-
self._master_pid = os.getpid()
493+
self._main_pid = os.getpid()
481494
self._flush_pending = False
482495
self._subprocess_flush_pending = False
483496
self._buffer_lock = threading.RLock()
@@ -569,8 +582,8 @@ def _setup_stream_redirects(self, name):
569582
self.watch_fd_thread.daemon = True
570583
self.watch_fd_thread.start()
571584

572-
def _is_master_process(self):
573-
return os.getpid() == self._master_pid
585+
def _is_main_process(self):
586+
return os.getpid() == self._main_pid
574587

575588
def set_parent(self, parent):
576589
"""Set the parent header."""
@@ -674,7 +687,7 @@ def _flush(self):
674687
ident=self.topic,
675688
)
676689

677-
def write(self, string: str) -> Optional[int]: # type:ignore[override]
690+
def write(self, string: str) -> int:
678691
"""Write to current stream after encoding if necessary
679692
680693
Returns
@@ -700,15 +713,15 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
700713
msg = "I/O operation on closed file"
701714
raise ValueError(msg)
702715

703-
is_child = not self._is_master_process()
716+
is_child = not self._is_main_process()
704717
# only touch the buffer in the IO thread to avoid races
705718
with self._buffer_lock:
706719
self._buffers[frozenset(parent.items())].write(string)
707720
if is_child:
708721
# mp.Pool cannot be trusted to flush promptly (or ever),
709722
# and this helps.
710723
if self._subprocess_flush_pending:
711-
return None
724+
return 0
712725
self._subprocess_flush_pending = True
713726
# We can not rely on self._io_loop.call_later from a subprocess
714727
self.pub_thread.schedule(self._flush)

tests/test_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import zmq.asyncio
1616
from jupyter_client.session import Session
1717

18-
from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream
18+
from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream
1919

2020

2121
@pytest.fixture()
@@ -73,7 +73,7 @@ async def test_io_thread(anyio_backend, iopub_thread):
7373
ctx1, pipe = thread._setup_pipe_out()
7474
pipe.close()
7575
thread._pipe_in1.close()
76-
thread._check_mp_mode = lambda: MASTER
76+
thread._check_mp_mode = lambda: _PARENT
7777
thread._really_send([b"hi"])
7878
ctx1.destroy()
7979
thread.stop()

tests/test_kernel.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
)
3333

3434

35-
def _check_master(kc, expected=True, stream="stdout"):
35+
def _check_main(kc, expected=True, stream="stdout"):
3636
execute(kc=kc, code="import sys")
3737
flush_channels(kc)
38-
msg_id, content = execute(kc=kc, code="print(sys.%s._is_master_process())" % stream)
38+
msg_id, content = execute(kc=kc, code="print(sys.%s._is_main_process())" % stream)
3939
stdout, stderr = assemble_output(kc.get_iopub_msg)
4040
assert stdout.strip() == repr(expected)
4141

@@ -56,7 +56,7 @@ def test_simple_print():
5656
stdout, stderr = assemble_output(kc.get_iopub_msg)
5757
assert stdout == "hi\n"
5858
assert stderr == ""
59-
_check_master(kc, expected=True)
59+
_check_main(kc, expected=True)
6060

6161

6262
def test_print_to_correct_cell_from_thread():
@@ -168,7 +168,7 @@ def test_capture_fd():
168168
stdout, stderr = assemble_output(iopub)
169169
assert stdout == "capsys\n"
170170
assert stderr == ""
171-
_check_master(kc, expected=True)
171+
_check_main(kc, expected=True)
172172

173173

174174
@pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing")
@@ -182,7 +182,7 @@ def test_subprocess_peek_at_stream_fileno():
182182
stdout, stderr = assemble_output(iopub)
183183
assert stdout == "CAP1\nCAP2\n"
184184
assert stderr == ""
185-
_check_master(kc, expected=True)
185+
_check_main(kc, expected=True)
186186

187187

188188
def test_sys_path():
@@ -218,7 +218,7 @@ def test_sys_path_profile_dir():
218218
def test_subprocess_print():
219219
"""printing from forked mp.Process"""
220220
with new_kernel() as kc:
221-
_check_master(kc, expected=True)
221+
_check_main(kc, expected=True)
222222
flush_channels(kc)
223223
np = 5
224224
code = "\n".join(
@@ -238,8 +238,8 @@ def test_subprocess_print():
238238
for n in range(np):
239239
assert stdout.count(str(n)) == 1, stdout
240240
assert stderr == ""
241-
_check_master(kc, expected=True)
242-
_check_master(kc, expected=True, stream="stderr")
241+
_check_main(kc, expected=True)
242+
_check_main(kc, expected=True, stream="stderr")
243243

244244

245245
@flaky(max_runs=3)
@@ -261,8 +261,8 @@ def test_subprocess_noprint():
261261
assert stdout == ""
262262
assert stderr == ""
263263

264-
_check_master(kc, expected=True)
265-
_check_master(kc, expected=True, stream="stderr")
264+
_check_main(kc, expected=True)
265+
_check_main(kc, expected=True, stream="stderr")
266266

267267

268268
@flaky(max_runs=3)
@@ -287,8 +287,8 @@ def test_subprocess_error():
287287
assert stdout == ""
288288
assert "ValueError" in stderr
289289

290-
_check_master(kc, expected=True)
291-
_check_master(kc, expected=True, stream="stderr")
290+
_check_main(kc, expected=True)
291+
_check_main(kc, expected=True, stream="stderr")
292292

293293

294294
# raw_input tests

0 commit comments

Comments
 (0)