diff --git a/docs/api/ipykernel.rst b/docs/api/ipykernel.rst index dd46d0842..27b25c893 100644 --- a/docs/api/ipykernel.rst +++ b/docs/api/ipykernel.rst @@ -134,12 +134,6 @@ Submodules :show-inheritance: -.. automodule:: ipykernel.trio_runner - :members: - :undoc-members: - :show-inheritance: - - .. automodule:: ipykernel.zmqshell :members: :undoc-members: diff --git a/ipykernel/debugger.py b/ipykernel/debugger.py index 57804d2db..bb8f338ac 100644 --- a/ipykernel/debugger.py +++ b/ipykernel/debugger.py @@ -242,7 +242,7 @@ async def _send_request(self, msg): self.log.debug("DEBUGPYCLIENT:") self.log.debug(self.routing_id) self.log.debug(buf) - await self.debugpy_socket.send_multipart((self.routing_id, buf)) + await self.debugpy_socket.asend_multipart((self.routing_id, buf)).wait() async def _wait_for_response(self): # Since events are never pushed to the message_queue @@ -438,7 +438,7 @@ async def start(self): (self.shell_socket.getsockopt(ROUTING_ID)), ) - msg = await self.shell_socket.recv_multipart() + msg = await self.shell_socket.arecv_multipart().wait() ident, msg = self.session.feed_identities(msg, copy=True) try: msg = self.session.deserialize(msg, content=True, copy=True) diff --git a/ipykernel/heartbeat.py b/ipykernel/heartbeat.py index 7706312e1..1340ccfc6 100644 --- a/ipykernel/heartbeat.py +++ b/ipykernel/heartbeat.py @@ -92,13 +92,17 @@ def _bind_socket(self): def run(self): """Run the heartbeat thread.""" self.name = "Heartbeat" - self.socket = self.context.socket(zmq.ROUTER) - self.socket.linger = 1000 + try: + self.socket = self.context.socket(zmq.ROUTER) + self.socket.linger = 1000 self._bind_socket() except Exception: - self.socket.close() - raise + try: + self.socket.close() + except Exception: + pass + return while True: try: diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index c6f8c6128..efaa594bd 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -54,7 +54,7 @@ class InProcessKernel(IPythonKernel): _underlying_iopub_socket = Instance(DummySocket, (False,)) iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment] - shell_socket = Instance(DummySocket, (True,)) # type:ignore[arg-type] + shell_socket = Instance(DummySocket, (True,)) @default("iopub_thread") def _default_iopub_thread(self): @@ -207,7 +207,7 @@ def enable_pylab(self, gui=None, import_all=True, welcome_message=False): """Activate pylab support at runtime.""" if not gui: gui = self.kernel.gui - return super().enable_pylab(gui, import_all, welcome_message) + return super().enable_pylab(gui, import_all, welcome_message) # type: ignore[call-arg] InteractiveShellABC.register(InProcessInteractiveShell) diff --git a/ipykernel/inprocess/session.py b/ipykernel/inprocess/session.py index 390ac9954..f8634a36d 100644 --- a/ipykernel/inprocess/session.py +++ b/ipykernel/inprocess/session.py @@ -12,7 +12,7 @@ async def recv( # type: ignore[override] mode, content, copy have no effect, but are present for superclass compatibility """ - return await socket.recv_multipart() + return await socket.arecv_multipart().wait() def send( self, diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index 05b45687c..10204c97c 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -65,4 +65,8 @@ async def poll(self, timeout=0): return statistics.current_buffer_used != 0 def close(self): - pass + if self.is_shell: + self.in_send_stream.close() + self.in_receive_stream.close() + 
self.out_send_stream.close() + self.out_receive_stream.close() diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 8cec0f42d..b98cbd988 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -20,6 +20,7 @@ from typing import Any, Callable import zmq +import zmq_anyio from anyio import sleep from jupyter_client.session import extract_header @@ -48,7 +49,7 @@ class IOPubThread: whose IO is always run in a thread. """ - def __init__(self, socket, pipe=False): + def __init__(self, socket: zmq_anyio.Socket, pipe: bool = False): """Create IOPub thread Parameters @@ -61,10 +62,7 @@ def __init__(self, socket, pipe=False): """ # ensure all of our sockets as sync zmq.Sockets # don't create async wrappers until we are within the appropriate coroutines - self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket) - if self.socket.context is None: - # bug in pyzmq, shadow socket doesn't always inherit context attribute - self.socket.context = socket.context # type:ignore[unreachable] + self.socket: zmq_anyio.Socket = socket self._context = socket.context self.background_socket = BackgroundSocket(self) @@ -78,7 +76,7 @@ def __init__(self, socket, pipe=False): self._event_pipe_gc_lock: threading.Lock = threading.Lock() self._event_pipe_gc_seconds: float = 10 self._setup_event_pipe() - tasks = [self._handle_event, self._run_event_pipe_gc] + tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start] if pipe: tasks.append(self._handle_pipe_msgs) self.thread = BaseThread(name="IOPub", daemon=True) @@ -87,7 +85,7 @@ def __init__(self, socket, pipe=False): def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" - self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket) + self._pipe_in0 = self._context.socket(zmq.PULL) self._pipe_in0.linger = 0 _uuid = b2a_hex(os.urandom(16)).decode("ascii") @@ -99,11 +97,11 @@ async def _run_event_pipe_gc(self): while True: await sleep(self._event_pipe_gc_seconds) try: - await self._event_pipe_gc() + self._event_pipe_gc() except Exception as e: print(f"Exception in IOPubThread._event_pipe_gc: {e}", file=sys.__stderr__) - async def _event_pipe_gc(self): + def _event_pipe_gc(self): """run a single garbage collection on event pipes""" if not self._event_pipes: # don't acquire the lock if there's nothing to do @@ -122,7 +120,7 @@ def _event_pipe(self): except AttributeError: # new thread, new event pipe # create sync base socket - event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket) + event_pipe = self._context.socket(zmq.PUSH) event_pipe.linger = 0 event_pipe.connect(self._event_interface) self._local.event_pipe = event_pipe @@ -141,30 +139,28 @@ async def _handle_event(self): Whenever *an* event arrives on the event stream, *all* waiting events are processed in order. 
""" - # create async wrapper within coroutine - pipe_in = zmq.asyncio.Socket(self._pipe_in0) - try: - while True: - await pipe_in.recv() - # freeze event count so new writes don't extend the queue - # while we are processing - n_events = len(self._events) - for _ in range(n_events): - event_f = self._events.popleft() - event_f() - except Exception: - if self.thread.stopped.is_set(): - return - raise + pipe_in = zmq_anyio.Socket(self._pipe_in0) + async with pipe_in: + try: + while True: + await pipe_in.arecv().wait() + # freeze event count so new writes don't extend the queue + # while we are processing + n_events = len(self._events) + for _ in range(n_events): + event_f = self._events.popleft() + event_f() + except Exception: + if self.thread.stopped.is_set(): + return + raise def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" - ctx = self._context - # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) - self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket) + self._pipe_in1 = zmq_anyio.Socket(self._context.socket(zmq.PULL)) self._pipe_in1.linger = 0 try: @@ -181,19 +177,18 @@ def _setup_pipe_in(self): async def _handle_pipe_msgs(self): """handle pipe messages from a subprocess""" - # create async wrapper within coroutine - self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1) - try: - while True: - await self._handle_pipe_msg() - except Exception: - if self.thread.stopped.is_set(): - return - raise + async with self._pipe_in1: + try: + while True: + await self._handle_pipe_msg() + except Exception: + if self.thread.stopped.is_set(): + return + raise async def _handle_pipe_msg(self, msg=None): """handle a pipe message from a subprocess""" - msg = msg or await self._async_pipe_in1.recv_multipart() + msg = msg or await self._pipe_in1.arecv_multipart().wait() if not self._pipe_flag or not self._is_main_process(): return if msg[0] != self._pipe_uuid: @@ -246,7 +241,10 @@ def close(self): """Close the IOPub thread.""" if self.closed: return - self._pipe_in0.close() + try: + self._pipe_in0.close() + except Exception: + pass if self._pipe_flag: self._pipe_in1.close() if self.socket is not None: diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index 5ba500198..b170d55e7 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -12,7 +12,7 @@ from dataclasses import dataclass import comm -import zmq.asyncio +import zmq_anyio from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread from anyio.abc import TaskStatus from IPython.core import release @@ -93,7 +93,7 @@ class IPythonKernel(KernelBase): help="Set this flag to False to deactivate the use of experimental IPython completion APIs.", ).tag(config=True) - debugpy_socket = Instance(zmq.asyncio.Socket, allow_none=True) + debugpy_socket = Instance(zmq_anyio.Socket, allow_none=True) user_module = Any() @@ -229,7 +229,8 @@ def __init__(self, **kwargs): } async def process_debugpy(self): - async with create_task_group() as tg: + assert self.debugpy_socket is not None + async with self.debug_shell_socket, self.debugpy_socket, create_task_group() as tg: tg.start_soon(self.receive_debugpy_messages) tg.start_soon(self.poll_stopped_queue) await to_thread.run_sync(self.debugpy_stop.wait) @@ -252,7 +253,7 @@ async def receive_debugpy_message(self, msg=None): if msg is None: assert self.debugpy_socket is not None - msg = await self.debugpy_socket.recv_multipart() + msg = await self.debugpy_socket.arecv_multipart().wait() # The first frame is the socket id, we 
can drop it frame = msg[1].decode("utf-8") self.log.debug("Debugpy received: %s", frame) diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index 676d2d46f..8078f97ce 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -19,7 +19,7 @@ from typing import Optional import zmq -import zmq.asyncio +import zmq_anyio from anyio import create_task_group, run, to_thread from IPython.core.application import ( # type:ignore[attr-defined] BaseIPythonApplication, @@ -333,15 +333,15 @@ def init_sockets(self): """Create a context, a session, and the kernel sockets.""" self.log.info("Starting the kernel at pid: %i", os.getpid()) assert self.context is None, "init_sockets cannot be called twice!" - self.context = context = zmq.asyncio.Context() + self.context = context = zmq.Context() atexit.register(self.close) - self.shell_socket = context.socket(zmq.ROUTER) + self.shell_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.shell_socket.linger = 1000 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port) self.log.debug("shell ROUTER Channel on port: %i", self.shell_port) - self.stdin_socket = zmq.Context(context).socket(zmq.ROUTER) + self.stdin_socket = context.socket(zmq.ROUTER) self.stdin_socket.linger = 1000 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port) self.log.debug("stdin ROUTER Channel on port: %i", self.stdin_port) @@ -357,18 +357,19 @@ def init_sockets(self): def init_control(self, context): """Initialize the control channel.""" - self.control_socket = context.socket(zmq.ROUTER) + self.control_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.control_socket.linger = 1000 self.control_port = self._bind_socket(self.control_socket, self.control_port) self.log.debug("control ROUTER Channel on port: %i", self.control_port) - self.debugpy_socket = context.socket(zmq.STREAM) + self.debugpy_socket = zmq_anyio.Socket(context.socket(zmq.STREAM)) self.debugpy_socket.linger = 1000 - self.debug_shell_socket = context.socket(zmq.DEALER) + self.debug_shell_socket = zmq_anyio.Socket(context.socket(zmq.DEALER)) self.debug_shell_socket.linger = 1000 - if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT): - self.debug_shell_socket.connect(self.shell_socket.getsockopt(zmq.LAST_ENDPOINT)) + last_endpoint = self.shell_socket.getsockopt(zmq.LAST_ENDPOINT) + if last_endpoint: + self.debug_shell_socket.connect(last_endpoint) if hasattr(zmq, "ROUTER_HANDOVER"): # set router-handover to workaround zeromq reconnect problems @@ -381,7 +382,7 @@ def init_control(self, context): def init_iopub(self, context): """Initialize the iopub channel.""" - self.iopub_socket = context.socket(zmq.PUB) + self.iopub_socket = zmq_anyio.Socket(context.socket(zmq.PUB)) self.iopub_socket.linger = 1000 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port) self.log.debug("iopub PUB Channel on port: %i", self.iopub_port) @@ -679,43 +680,6 @@ def configure_tornado_logger(self): handler.setFormatter(formatter) logger.addHandler(handler) - def _init_asyncio_patch(self): - """set default asyncio policy to be compatible with tornado - - Tornado 6 (at least) is not compatible with the default - asyncio implementation on Windows - - Pick the older SelectorEventLoopPolicy on Windows - if the known-incompatible default policy is in use. - - Support for Proactor via a background thread is available in tornado 6.1, - but it is still preferable to run the Selector in the main thread - instead of the background. 
- - do this as early as possible to make it a low priority and overridable - - ref: https://github.com/tornadoweb/tornado/issues/2608 - - FIXME: if/when tornado supports the defaults in asyncio without threads, - remove and bump tornado requirement for py38. - Most likely, this will mean a new Python version - where asyncio.ProactorEventLoop supports add_reader and friends. - - """ - if sys.platform.startswith("win"): - import asyncio - - try: - from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy - except ImportError: - pass - # not affected - else: - if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy: - # WindowsProactorEventLoopPolicy is not compatible with tornado 6 - # fallback to the pre-3.8 default of Selector - asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy()) - def init_pdb(self): """Replace pdb with IPython's version that is interruptible. @@ -735,7 +699,6 @@ def init_pdb(self): @catch_config_error def initialize(self, argv=None) -> None: """Initialize the application.""" - self._init_asyncio_patch() super().initialize(argv) if self.subapp is not None: return @@ -772,7 +735,7 @@ def initialize(self, argv=None) -> None: sys.stdout.flush() sys.stderr.flush() - async def _start(self) -> None: + async def _start(self, backend: str) -> None: """ Async version of start, when the loop is not controlled by IPykernel @@ -783,12 +746,23 @@ async def _start(self) -> None: return if self.poller is not None: self.poller.start() + + if backend == "asyncio" and sys.platform == "win32": + import asyncio + + policy = asyncio.get_event_loop_policy() + if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy": + from anyio._core._asyncio_selector_thread import get_selector + + selector = get_selector() + selector._thread.pydev_do_not_trace = True + await self.main() def start(self) -> None: """Start the application.""" backend = "trio" if self.trio_loop else "asyncio" - run(self._start, backend=backend) + run(partial(self._start, backend), backend=backend) async def _wait_to_enter_eventloop(self) -> None: await to_thread.run_sync(self.kernel._eventloop_set.wait) diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index 08f60e14e..6465751c2 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -36,6 +36,7 @@ import psutil import zmq +import zmq_anyio from anyio import TASK_STATUS_IGNORED, create_task_group, sleep, to_thread from anyio.abc import TaskStatus from IPython.core.error import StdinNotImplementedError @@ -97,7 +98,7 @@ class Kernel(SingletonConfigurable): session = Instance(Session, allow_none=True) profile_dir = Instance("IPython.core.profiledir.ProfileDir", allow_none=True) - shell_socket = Instance(zmq.asyncio.Socket, allow_none=True) + shell_socket = Instance(zmq_anyio.Socket, allow_none=True) implementation: str implementation_version: str @@ -105,7 +106,7 @@ class Kernel(SingletonConfigurable): _is_test = Bool(False) - control_socket = Instance(zmq.asyncio.Socket, allow_none=True) + control_socket = Instance(zmq_anyio.Socket, allow_none=True) control_tasks: t.Any = List() debug_shell_socket = Any() @@ -278,7 +279,7 @@ async def process_control_message(self, msg=None): assert self.session is not None assert self.control_thread is None or threading.current_thread() == self.control_thread - msg = msg or await self.control_socket.recv_multipart() + msg = msg or await self.control_socket.arecv_multipart().wait() idents, msg = self.session.feed_identities(msg) try: msg = 
self.session.deserialize(msg, content=True) @@ -375,26 +376,31 @@ async def shell_channel_thread_main(self): assert self.shell_channel_thread is not None assert threading.current_thread() == self.shell_channel_thread - try: - while True: - msg = await self.shell_socket.recv_multipart(copy=False) - # deserialize only the header to get subshell_id - # Keep original message to send to subshell_id unmodified. - _, msg2 = self.session.feed_identities(msg, copy=False) - try: - msg3 = self.session.deserialize(msg2, content=False, copy=False) - subshell_id = msg3["header"].get("subshell_id") - - # Find inproc pair socket to use to send message to correct subshell. - socket = self.shell_channel_thread.manager.get_shell_channel_socket(subshell_id) - assert socket is not None - socket.send_multipart(msg, copy=False) - except Exception: - self.log.error("Invalid message", exc_info=True) # noqa: G201 - except BaseException: - if self.shell_stop.is_set(): - return - raise + async with self.shell_socket, create_task_group() as tg: + try: + while True: + msg = await self.shell_socket.arecv_multipart(copy=False).wait() + # deserialize only the header to get subshell_id + # Keep original message to send to subshell_id unmodified. + _, msg2 = self.session.feed_identities(msg, copy=False) + try: + msg3 = self.session.deserialize(msg2, content=False, copy=False) + subshell_id = msg3["header"].get("subshell_id") + + # Find inproc pair socket to use to send message to correct subshell. + socket = self.shell_channel_thread.manager.get_shell_channel_socket( + subshell_id + ) + assert socket is not None + if not socket.started.is_set(): + await tg.start(socket.start) + await socket.asend_multipart(msg, copy=False).wait() + except Exception: + self.log.error("Invalid message", exc_info=True) # noqa: G201 + except BaseException: + if self.shell_stop.is_set(): + return + raise async def shell_main(self, subshell_id: str | None): """Main loop for a single subshell.""" @@ -414,6 +420,8 @@ async def shell_main(self, subshell_id: str | None): socket = None async with create_task_group() as tg: + if not socket.started.is_set(): + await tg.start(socket.start) tg.start_soon(self.process_shell, socket) if subshell_id is None: # Main subshell. 
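A note on the two hunks above (hedged, not part of the patch): a zmq-anyio socket's async methods are usable only after the socket has been started inside a running task group, which is why the shell channel loop now checks `started` and calls `tg.start(socket.start)` lazily before the first send. A minimal sketch of that pattern; `get_socket` is a hypothetical stand-in for `manager.get_shell_channel_socket(subshell_id)`:

# Sketch only: lazily start a zmq_anyio.Socket before forwarding to it.
from anyio import create_task_group


async def forward_loop(shell_socket, get_socket):
    # Long-lived loop, like shell_channel_thread_main above; the task group
    # deliberately stays open so lazily started sockets keep running.
    async with shell_socket, create_task_group() as tg:
        while True:
            msg = await shell_socket.arecv_multipart(copy=False).wait()
            socket = get_socket(msg)             # hypothetical subshell lookup
            if not socket.started.is_set():      # not yet running in this loop
                await tg.start(socket.start)     # start it inside the task group
            await socket.asend_multipart(msg, copy=False).wait()

Starting each pair socket on first use, rather than eagerly, keeps socket start-up in the thread and event loop that will actually drive it.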
@@ -446,8 +454,8 @@ async def process_shell_message(self, msg=None, socket=None): assert socket is None socket = self.shell_socket - no_msg = msg is None if self._is_test else not await socket.poll(0) - msg = msg or await socket.recv_multipart(copy=False) + no_msg = msg is None if self._is_test else not await socket.apoll(0).wait() + msg = msg or await socket.arecv_multipart(copy=False).wait() received_time = time.monotonic() copy = not isinstance(msg[0], zmq.Message) @@ -520,7 +528,8 @@ async def process_shell_message(self, msg=None, socket=None): self._publish_status("idle", "shell") async def control_main(self): - async with create_task_group() as tg: + assert self.control_socket is not None + async with self.control_socket, create_task_group() as tg: for task in self.control_tasks: tg.start_soon(task) tg.start_soon(self.process_control) @@ -557,7 +566,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: manager = self.shell_channel_thread.manager self.shell_channel_thread.start_soon(self.shell_channel_thread_main) self.shell_channel_thread.start_soon( - partial(manager.listen_from_control, self.shell_main) + partial(manager.listen_from_control, self.shell_main, self.shell_channel_thread) ) self.shell_channel_thread.start_soon(manager.listen_from_subshells) self.shell_channel_thread.start() @@ -1137,9 +1146,11 @@ async def create_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. - other_socket = self.shell_channel_thread.manager.get_control_other_socket() - await other_socket.send_json({"type": "create"}) - reply = await other_socket.recv_json() + other_socket = await self.shell_channel_thread.manager.get_control_other_socket( + self.control_thread + ) + await other_socket.asend_json({"type": "create"}).wait() + reply = await other_socket.arecv_json().wait() self.session.send(socket, "create_subshell_reply", reply, parent, ident) @@ -1159,9 +1170,11 @@ async def delete_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. - other_socket = self.shell_channel_thread.manager.get_control_other_socket() - await other_socket.send_json({"type": "delete", "subshell_id": subshell_id}) - reply = await other_socket.recv_json() + other_socket = await self.shell_channel_thread.manager.get_control_other_socket( + self.control_thread + ) + await other_socket.asend_json({"type": "delete", "subshell_id": subshell_id}).wait() + reply = await other_socket.arecv_json().wait() self.session.send(socket, "delete_subshell_reply", reply, parent, ident) @@ -1174,9 +1187,11 @@ async def list_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. 
- other_socket = self.shell_channel_thread.manager.get_control_other_socket() - await other_socket.send_json({"type": "list"}) - reply = await other_socket.recv_json() + other_socket = await self.shell_channel_thread.manager.get_control_other_socket( + self.control_thread + ) + await other_socket.asend_json({"type": "list"}).wait() + reply = await other_socket.arecv_json().wait() self.session.send(socket, "list_subshell_reply", reply, parent, ident) diff --git a/ipykernel/shellchannel.py b/ipykernel/shellchannel.py index 10abdb359..30ea6437b 100644 --- a/ipykernel/shellchannel.py +++ b/ipykernel/shellchannel.py @@ -1,6 +1,7 @@ """A thread for a shell channel.""" -import zmq.asyncio +import zmq +import zmq_anyio from .subshell_manager import SubshellManager from .thread import SHELL_CHANNEL_THREAD_NAME, BaseThread @@ -12,7 +13,12 @@ class ShellChannelThread(BaseThread): Communicates with shell/subshell threads via pairs of ZMQ inproc sockets. """ - def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socket, **kwargs): + def __init__( + self, + context: zmq.Context, # type: ignore[type-arg] + shell_socket: zmq_anyio.Socket, + **kwargs, + ): """Initialize the thread.""" super().__init__(name=SHELL_CHANNEL_THREAD_NAME, **kwargs) self._manager: SubshellManager | None = None diff --git a/ipykernel/subshell.py b/ipykernel/subshell.py index 18e15ab38..180e9ecb3 100644 --- a/ipykernel/subshell.py +++ b/ipykernel/subshell.py @@ -2,7 +2,8 @@ from threading import current_thread -import zmq.asyncio +import zmq +import zmq_anyio from .thread import BaseThread @@ -15,17 +16,22 @@ def __init__(self, subshell_id: str, **kwargs): super().__init__(name=f"subshell-{subshell_id}", **kwargs) # Inproc PAIR socket, for communication with shell channel thread. - self._pair_socket: zmq.asyncio.Socket | None = None + self._pair_socket: zmq_anyio.Socket | None = None - async def create_pair_socket(self, context: zmq.asyncio.Context, address: str) -> None: + async def create_pair_socket( + self, + context: zmq.Context, # type: ignore[type-arg] + address: str, + ) -> None: """Create inproc PAIR socket, for communication with shell channel thread. - Should be called from this thread, so usually via add_task before the + Should be called from this thread, so usually via start_soon before the thread is started. """ assert current_thread() == self - self._pair_socket = context.socket(zmq.PAIR) + self._pair_socket = zmq_anyio.Socket(context, zmq.PAIR) self._pair_socket.connect(address) + self.start_soon(self._pair_socket.start) def run(self) -> None: try: diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index 14c4c57c3..f4f92d2da 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -11,17 +11,18 @@ from threading import Lock, current_thread, main_thread import zmq -import zmq.asyncio +import zmq_anyio from anyio import create_memory_object_stream, create_task_group +from anyio.abc import TaskGroup from .subshell import SubshellThread -from .thread import SHELL_CHANNEL_THREAD_NAME +from .thread import SHELL_CHANNEL_THREAD_NAME, BaseThread @dataclass class Subshell: thread: SubshellThread - shell_channel_socket: zmq.asyncio.Socket + shell_channel_socket: zmq_anyio.Socket class SubshellManager: @@ -39,10 +40,14 @@ class SubshellManager: against multiple subshells attempting to send at the same time. 
""" - def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socket): + def __init__( + self, + context: zmq.Context, # type: ignore[type-arg] + shell_socket: zmq_anyio.Socket, + ): assert current_thread() == main_thread() - self._context: zmq.asyncio.Context = context + self._context: zmq.Context = context # type: ignore[type-arg] self._shell_socket = shell_socket self._cache: dict[str, Subshell] = {} self._lock_cache = Lock() @@ -51,15 +56,39 @@ def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socke # Inproc pair sockets for control channel and main shell (parent subshell). # Each inproc pair has a "shell_channel" socket used in the shell channel # thread, and an "other" socket used in the other thread. - self._control_shell_channel_socket = self._create_inproc_pair_socket("control", True) - self._control_other_socket = self._create_inproc_pair_socket("control", False) - self._parent_shell_channel_socket = self._create_inproc_pair_socket(None, True) - self._parent_other_socket = self._create_inproc_pair_socket(None, False) + self.__control_shell_channel_socket: zmq_anyio.Socket | None = None + self.__control_other_socket: zmq_anyio.Socket | None = None + self.__parent_shell_channel_socket: zmq_anyio.Socket | None = None + self.__parent_other_socket: zmq_anyio.Socket | None = None # anyio memory object stream for async queue-like communication between tasks. # Used by _create_subshell to tell listen_from_subshells to spawn a new task. self._send_stream, self._receive_stream = create_memory_object_stream[str]() + @property + def _control_shell_channel_socket(self) -> zmq_anyio.Socket: + if self.__control_shell_channel_socket is None: + self.__control_shell_channel_socket = self._create_inproc_pair_socket("control", True) + return self.__control_shell_channel_socket + + @property + def _control_other_socket(self) -> zmq_anyio.Socket: + if self.__control_other_socket is None: + self.__control_other_socket = self._create_inproc_pair_socket("control", False) + return self.__control_other_socket + + @property + def _parent_shell_channel_socket(self) -> zmq_anyio.Socket: + if self.__parent_shell_channel_socket is None: + self.__parent_shell_channel_socket = self._create_inproc_pair_socket(None, True) + return self.__parent_shell_channel_socket + + @property + def _parent_other_socket(self) -> zmq_anyio.Socket: + if self.__parent_other_socket is None: + self.__parent_other_socket = self._create_inproc_pair_socket(None, False) + return self.__parent_other_socket + def close(self) -> None: """Stop all subshells and close all resources.""" assert current_thread().name == SHELL_CHANNEL_THREAD_NAME @@ -68,10 +97,10 @@ def close(self) -> None: self._receive_stream.close() for socket in ( - self._control_shell_channel_socket, - self._control_other_socket, - self._parent_shell_channel_socket, - self._parent_other_socket, + self.__control_shell_channel_socket, + self.__control_other_socket, + self.__parent_shell_channel_socket, + self.__parent_other_socket, ): if socket is not None: socket.close() @@ -84,10 +113,17 @@ def close(self) -> None: break self._stop_subshell(subshell) - def get_control_other_socket(self) -> zmq.asyncio.Socket: + async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket: + if not self._control_other_socket.started.is_set(): + await thread.task_group.start(self._control_other_socket.start) return self._control_other_socket - def get_other_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + async def 
get_control_shell_channel_socket(self, thread: BaseThread) -> zmq_anyio.Socket: + if not self._control_shell_channel_socket.started.is_set(): + await thread.task_group.start(self._control_shell_channel_socket.start) + return self._control_shell_channel_socket + + def get_other_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: """Return the other inproc pair socket for a subshell. This socket is accessed from the subshell thread. @@ -99,7 +135,7 @@ def get_other_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: assert socket is not None return socket - def get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + def get_shell_channel_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: """Return the shell channel inproc pair socket for a subshell. This socket is accessed from the shell channel thread. @@ -117,17 +153,17 @@ def list_subshell(self) -> list[str]: with self._lock_cache: return list(self._cache) - async def listen_from_control(self, subshell_task: t.Any) -> None: + async def listen_from_control(self, subshell_task: t.Any, thread: BaseThread) -> None: """Listen for messages on the control inproc socket, handle those messages and return replies on the same socket. Runs in the shell channel thread. """ assert current_thread().name == SHELL_CHANNEL_THREAD_NAME - socket = self._control_shell_channel_socket + socket = await self.get_control_shell_channel_socket(thread) while True: - request = await socket.recv_json() + request = await socket.arecv_json().wait() reply = await self._process_control_request(request, subshell_task) - await socket.send_json(reply) + await socket.asend_json(reply).wait() async def listen_from_subshells(self) -> None: """Listen for reply messages on inproc sockets of all subshells and resend @@ -138,9 +174,9 @@ async def listen_from_subshells(self) -> None: assert current_thread().name == SHELL_CHANNEL_THREAD_NAME async with create_task_group() as tg: - tg.start_soon(self._listen_for_subshell_reply, None) + tg.start_soon(self._listen_for_subshell_reply, None, tg) async for subshell_id in self._receive_stream: - tg.start_soon(self._listen_for_subshell_reply, subshell_id) + tg.start_soon(self._listen_for_subshell_reply, subshell_id, tg) def subshell_id_from_thread_id(self, thread_id: int) -> str | None: """Return subshell_id of the specified thread_id. 
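For orientation (hedged, not part of the patch): the hunks above replace the eagerly created inproc PAIR sockets with properties that create each socket on first access, and the `get_control_*_socket` helpers now start the socket in the caller's task group. A rough sketch of the cached-property shape; the class, attribute names, and address are illustrative, not the exact ipykernel ones:

import zmq
import zmq_anyio


class PairSockets:
    # Illustrative container; the real ipykernel class is SubshellManager.
    def __init__(self, context: zmq.Context) -> None:
        self._context = context
        self.__shell_end = None  # zmq_anyio.Socket, created lazily

    @property
    def _shell_end(self) -> zmq_anyio.Socket:
        # Create on first access so the socket can later be started in
        # whichever thread/event loop actually uses it.
        if self.__shell_end is None:
            socket = zmq_anyio.Socket(self._context, zmq.PAIR)
            socket.bind("inproc://example-pair")  # address is illustrative
            self.__shell_end = socket
        return self.__shell_end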
@@ -160,10 +196,10 @@ def subshell_id_from_thread_id(self, thread_id: int) -> str | None: def _create_inproc_pair_socket( self, name: str | None, shell_channel_end: bool - ) -> zmq.asyncio.Socket: + ) -> zmq_anyio.Socket: """Create and return a single ZMQ inproc pair socket.""" address = self._get_inproc_socket_address(name) - socket = self._context.socket(zmq.PAIR) + socket = zmq_anyio.Socket(self._context, zmq.PAIR) if shell_channel_end: socket.bind(address) else: @@ -209,7 +245,7 @@ def _get_inproc_socket_address(self, name: str | None) -> str: full_name = f"subshell-{name}" if name else "subshell" return f"inproc://{full_name}" - def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: if subshell_id is None: return self._parent_shell_channel_socket with self._lock_cache: @@ -221,7 +257,9 @@ def _is_subshell(self, subshell_id: str | None) -> bool: with self._lock_cache: return subshell_id in self._cache - async def _listen_for_subshell_reply(self, subshell_id: str | None) -> None: + async def _listen_for_subshell_reply( + self, subshell_id: str | None, task_group: TaskGroup + ) -> None: """Listen for reply messages on specified subshell inproc socket and resend to the client via the shell_socket. @@ -231,11 +269,13 @@ async def _listen_for_subshell_reply(self, subshell_id: str | None) -> None: shell_channel_socket = self._get_shell_channel_socket(subshell_id) + if not shell_channel_socket.started.is_set(): + await task_group.start(shell_channel_socket.start) try: while True: - msg = await shell_channel_socket.recv_multipart(copy=False) + msg = await shell_channel_socket.arecv_multipart(copy=False).wait() with self._lock_shell_socket: - await self._shell_socket.send_multipart(msg) + await self._shell_socket.asend_multipart(msg).wait() except BaseException: if not self._is_subshell(subshell_id): # Subshell no longer exists so exit gracefully diff --git a/ipykernel/trio_runner.py b/ipykernel/trio_runner.py deleted file mode 100644 index 6fb44107b..000000000 --- a/ipykernel/trio_runner.py +++ /dev/null @@ -1,72 +0,0 @@ -"""A trio loop runner.""" - -import builtins -import logging -import signal -import threading -import traceback -import warnings - -import trio - - -class TrioRunner: - """A trio loop runner.""" - - def __init__(self): - """Initialize the runner.""" - self._cell_cancel_scope = None - self._trio_token = None - - def initialize(self, kernel, io_loop): - """Initialize the runner.""" - kernel.shell.set_trio_runner(self) - kernel.shell.run_line_magic("autoawait", "trio") - kernel.shell.magics_manager.magics["line"]["autoawait"] = lambda _: warnings.warn( - "Autoawait isn't allowed in Trio background loop mode.", stacklevel=2 - ) - self._interrupted = False - bg_thread = threading.Thread(target=io_loop.start, daemon=True, name="TornadoBackground") - bg_thread.start() - - def interrupt(self, signum, frame): - """Interrupt the runner.""" - if self._cell_cancel_scope: - self._cell_cancel_scope.cancel() - else: - msg = "Kernel interrupted but no cell is running" - raise Exception(msg) - - def run(self): - """Run the loop.""" - old_sig = signal.signal(signal.SIGINT, self.interrupt) - - def log_nursery_exc(exc): - exc = "\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) - logging.error("An exception occurred in a global nursery task.\n%s", exc) - - async def trio_main(): - """Run the main loop.""" - self._trio_token = trio.lowlevel.current_trio_token() - 
async with trio.open_nursery() as nursery: - # TODO This hack prevents the nursery from cancelling all child - # tasks when an uncaught exception occurs, but it's ugly. - nursery._add_exc = log_nursery_exc - builtins.GLOBAL_NURSERY = nursery # type:ignore[attr-defined] - await trio.sleep_forever() - - trio.run(trio_main) - signal.signal(signal.SIGINT, old_sig) - - def __call__(self, async_fn): - """Handle a function call.""" - - async def loc(coro): - """A thread runner context.""" - self._cell_cancel_scope = trio.CancelScope() - with self._cell_cancel_scope: - return await coro - self._cell_cancel_scope = None # type:ignore[unreachable] - return None - - return trio.from_thread.run(loc, async_fn, trio_token=self._trio_token) diff --git a/pyproject.toml b/pyproject.toml index 908ac5370..1853544fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,11 @@ dependencies = [ "nest_asyncio>=1.4", "matplotlib-inline>=0.1", 'appnope>=0.1.2;platform_system=="Darwin"', - "pyzmq>=25.0", + "pyzmq>=26.0", "psutil>=5.7", "packaging>=22", - "anyio>=4.2.0", + "anyio>=4.8.0,<5.0.0", + "zmq-anyio >=0.3.6", ] [project.urls] @@ -230,9 +231,12 @@ filterwarnings= [ # ignore unclosed sqlite in traits "ignore:unclosed database in .trigger_timeout' was never awaited", + "ignore: Unclosed socket", + # ignore deprecated non async during tests: "always:For consistency across implementations, it is recommended that:PendingDeprecationWarning", - ] [tool.coverage.report] @@ -342,3 +346,6 @@ ignore = ["W002"] [tool.repo-review] ignore = ["PY007", "PP308", "GH102", "MY101"] + +[tool.hatch.metadata] +allow-direct-references = true diff --git a/tests/conftest.py b/tests/conftest.py index 76e780af3..32524e0ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,15 @@ -import asyncio +import gc import logging -import os import warnings from math import inf +from threading import Event from typing import Any, Callable, no_type_check from unittest.mock import MagicMock import pytest import zmq -import zmq.asyncio -from anyio import create_memory_object_stream, create_task_group +import zmq_anyio +from anyio import create_memory_object_stream, create_task_group, sleep from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from jupyter_client.session import Session @@ -17,6 +17,12 @@ from ipykernel.kernelbase import Kernel from ipykernel.zmqshell import ZMQInteractiveShell + +@pytest.fixture(scope="session", autouse=True) +def _garbage_collection(request): + gc.collect() + + try: import resource except ImportError: @@ -28,12 +34,6 @@ except ModuleNotFoundError: tracemalloc = None - -@pytest.fixture() -def anyio_backend(): - return "asyncio" - - pytestmark = pytest.mark.anyio @@ -52,11 +52,6 @@ def anyio_backend(): resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) -# Enforce selector event loop on Windows. -if os.name == "nt": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type:ignore - - class TestSession(Session): """A session that copies sent messages to an internal stream, so that they can be accessed later. 
@@ -83,21 +78,21 @@ def send(self, socket, *args, **kwargs): class KernelMixin: - shell_socket: zmq.asyncio.Socket - control_socket: zmq.asyncio.Socket + shell_socket: zmq_anyio.Socket + control_socket: zmq_anyio.Socket stop: Callable[[], None] log = logging.getLogger() def _initialize(self): self._is_test = True - self.context = context = zmq.asyncio.Context() - self.iopub_socket = context.socket(zmq.PUB) - self.stdin_socket = context.socket(zmq.ROUTER) + self.context = context = zmq.Context() + self.iopub_socket = zmq_anyio.Socket(context.socket(zmq.PUB)) + self.stdin_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.test_sockets = [self.iopub_socket] for name in ["shell", "control"]: - socket = context.socket(zmq.ROUTER) + socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.test_sockets.append(socket) setattr(self, f"{name}_socket", socket) @@ -148,7 +143,7 @@ def _prep_msg(self, *args, **kwargs): async def _wait_for_msg(self): while not self._reply: - await asyncio.sleep(0.1) + await sleep(0.1) _, msg = self.session.feed_identities(self._reply) return self.session.deserialize(msg) @@ -172,6 +167,8 @@ class MockKernel(KernelMixin, Kernel): # type:ignore def __init__(self, *args, **kwargs): self._initialize() self.shell = MagicMock() + self.shell_stop = Event() + self.control_stop = Event() super().__init__(*args, **kwargs) async def do_execute( @@ -193,6 +190,8 @@ async def do_execute( class MockIPyKernel(KernelMixin, IPythonKernel): # type:ignore def __init__(self, *args, **kwargs): self._initialize() + self.shell_stop = Event() + self.control_stop = Event() super().__init__(*args, **kwargs) @@ -201,8 +200,10 @@ async def kernel(anyio_backend): async with create_task_group() as tg: kernel = MockKernel() tg.start_soon(kernel.start) - yield kernel - kernel.destroy() + try: + yield kernel + finally: + kernel.destroy() @pytest.fixture() @@ -210,9 +211,11 @@ async def ipkernel(anyio_backend): async with create_task_group() as tg: kernel = MockIPyKernel() tg.start_soon(kernel.start) - yield kernel - kernel.destroy() - ZMQInteractiveShell.clear_instance() + try: + yield kernel + finally: + kernel.destroy() + ZMQInteractiveShell.clear_instance() @pytest.fixture() diff --git a/tests/test_async.py b/tests/test_async.py index a40db4a00..c2dd980b9 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -30,24 +30,23 @@ def test_async_await(): assert content["status"] == "ok", content -# FIXME: @pytest.mark.parametrize("asynclib", ["asyncio", "trio", "curio"]) @pytest.mark.skipif(os.name == "nt", reason="Cannot interrupt on Windows") -@pytest.mark.parametrize("asynclib", ["asyncio"]) -def test_async_interrupt(asynclib, request): +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) # FIXME: %autoawait trio +def test_async_interrupt(anyio_backend, request): assert KC is not None assert KM is not None try: - __import__(asynclib) + __import__(anyio_backend) except ImportError: - pytest.skip("Requires %s" % asynclib) - request.addfinalizer(lambda: execute("%autoawait asyncio", KC)) + pytest.skip("Requires %s" % anyio_backend) + request.addfinalizer(lambda: execute(f"%autoawait {anyio_backend}", KC)) flush_channels(KC) - msg_id, content = execute("%autoawait " + asynclib, KC) + msg_id, content = execute(f"%autoawait {anyio_backend}", KC) assert content["status"] == "ok", content flush_channels(KC) - msg_id = KC.execute(f"print('begin'); import {asynclib}; await {asynclib}.sleep(5)") + msg_id = KC.execute(f"print('begin'); import {anyio_backend}; await {anyio_backend}.sleep(5)") 
busy = KC.get_iopub_msg(timeout=TIMEOUT) validate_message(busy, "status", msg_id) assert busy["content"]["execution_state"] == "busy" diff --git a/tests/test_embed_kernel.py b/tests/test_embed_kernel.py index 0c74dd1f0..37d1aaf46 100644 --- a/tests/test_embed_kernel.py +++ b/tests/test_embed_kernel.py @@ -145,7 +145,10 @@ def test_embed_kernel_namespace(): with setup_kernel(cmd) as client: # oinfo a (int) client.inspect("a") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert content["found"] text = content["data"]["text/plain"] @@ -153,7 +156,10 @@ def test_embed_kernel_namespace(): # oinfo b (str) client.inspect("b") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert content["found"] text = content["data"]["text/plain"] @@ -161,7 +167,10 @@ def test_embed_kernel_namespace(): # oinfo c (undefined) client.inspect("c") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert not content["found"] @@ -186,7 +195,10 @@ def test_embed_kernel_reentrant(): with setup_kernel(cmd) as client: for i in range(5): client.inspect("count") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert content["found"] text = content["data"]["text/plain"] diff --git a/tests/test_eventloop.py b/tests/test_eventloop.py index 7dc0106c8..e4aef5711 100644 --- a/tests/test_eventloop.py +++ b/tests/test_eventloop.py @@ -79,6 +79,7 @@ def do_thing(): @windows_skip +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) def test_asyncio_loop(kernel): def do_thing(): loop.call_later(0.01, loop.stop) diff --git a/tests/test_io.py b/tests/test_io.py index e3ff28159..17f955af6 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,31 +12,42 @@ import pytest import zmq -import zmq.asyncio +import zmq_anyio +from anyio import create_task_group from jupyter_client.session import Session from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream +pytestmark = pytest.mark.anyio + @pytest.fixture() def ctx(): - ctx = zmq.asyncio.Context() + ctx = zmq.Context() yield ctx ctx.destroy() @pytest.fixture() -def iopub_thread(ctx): - with ctx.socket(zmq.PUB) as pub: - thread = IOPubThread(pub) - thread.start() - - yield thread - thread.stop() - thread.close() - - -def test_io_api(iopub_thread): +async def iopub_thread(ctx): + try: + async with create_task_group() as tg: + pub = zmq_anyio.Socket(ctx.socket(zmq.PUB)) + await tg.start(pub.start) + thread = IOPubThread(pub) + thread.start() + + try: + yield thread + finally: + await pub.stop() + thread.stop() + thread.close() + except BaseException: + pass + + +async def test_io_api(iopub_thread): """Test that wrapped stdout has the same API as a normal TextIO object""" session = Session() stream = OutStream(session, iopub_thread, "stdout") @@ -59,13 +70,13 @@ def test_io_api(iopub_thread): stream.write(b"") # type:ignore -def test_io_isatty(iopub_thread): +async def test_io_isatty(iopub_thread): session = Session() stream = OutStream(session, iopub_thread, "stdout", isatty=True) assert stream.isatty() -async def test_io_thread(anyio_backend, 
iopub_thread): +async def test_io_thread(iopub_thread): thread = iopub_thread thread._setup_pipe_in() msg = [thread._pipe_uuid, b"a"] @@ -77,11 +88,9 @@ async def test_io_thread(anyio_backend, iopub_thread): thread._really_send([b"hi"]) ctx1.destroy() thread.stop() - thread.close() - thread._really_send(None) -async def test_background_socket(anyio_backend, iopub_thread): +async def test_background_socket(iopub_thread): sock = BackgroundSocket(iopub_thread) assert sock.__class__ == BackgroundSocket with warnings.catch_warnings(): @@ -92,7 +101,7 @@ async def test_background_socket(anyio_backend, iopub_thread): sock.send(b"hi") -async def test_outstream(anyio_backend, iopub_thread): +async def test_outstream(iopub_thread): session = Session() pub = iopub_thread.socket @@ -118,7 +127,7 @@ async def test_outstream(anyio_backend, iopub_thread): assert stream.writable() -@pytest.mark.anyio() +@pytest.mark.skip(reason="Cannot use a zmq-anyio socket on different threads") async def test_event_pipe_gc(iopub_thread): session = Session(key=b"abc") stream = OutStream( @@ -139,7 +148,7 @@ async def test_event_pipe_gc(iopub_thread): f: Future = Future() try: - await iopub_thread._event_pipe_gc() + iopub_thread._event_pipe_gc() except Exception as e: f.set_exception(e) else: @@ -150,12 +159,13 @@ async def test_event_pipe_gc(iopub_thread): # assert iopub_thread._event_pipes == {} -def subprocess_test_echo_watch(): +async def subprocess_test_echo_watch(): # handshake Pub subscription session = Session(key=b"abc") # use PUSH socket to avoid subscription issues - with zmq.asyncio.Context() as ctx, ctx.socket(zmq.PUSH) as pub: + with zmq.Context() as ctx: + pub = zmq_anyio.Socket(ctx.socket(zmq.PUSH)) pub.connect(os.environ["IOPUB_URL"]) iopub_thread = IOPubThread(pub) iopub_thread.start() @@ -192,19 +202,18 @@ def subprocess_test_echo_watch(): iopub_thread.close() -@pytest.mark.anyio() @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows") async def test_echo_watch(ctx): """Test echo on underlying FD while capturing the same FD Test runs in a subprocess to avoid messing with pytest output capturing. 
""" - s = ctx.socket(zmq.PULL) + s = zmq_anyio.Socket(ctx.socket(zmq.PULL)) port = s.bind_to_random_port("tcp://127.0.0.1") url = f"tcp://127.0.0.1:{port}" session = Session(key=b"abc") stdout_chunks = [] - with s: + async with s: env = dict(os.environ) env["IOPUB_URL"] = url env["PYTHONUNBUFFERED"] = "1" @@ -213,7 +222,7 @@ async def test_echo_watch(ctx): [ sys.executable, "-c", - f"import {__name__}; {__name__}.subprocess_test_echo_watch()", + f"import {__name__}, anyio; anyio.run({__name__}.subprocess_test_echo_watch)", ], env=env, capture_output=True, @@ -224,8 +233,8 @@ async def test_echo_watch(ctx): print(f"{p.stdout=}") print(f"{p.stderr}=", file=sys.stderr) assert p.returncode == 0 - while await s.poll(timeout=100): - msg = await s.recv_multipart() + while await s.apoll(timeout=100).wait(): + msg = await s.arecv_multipart().wait() ident, msg = session.feed_identities(msg, copy=True) msg = session.deserialize(msg, content=True, copy=True) assert msg is not None # for type narrowing diff --git a/tests/test_kernelapp.py b/tests/test_kernelapp.py index 0f1d04373..cc010740d 100644 --- a/tests/test_kernelapp.py +++ b/tests/test_kernelapp.py @@ -130,7 +130,7 @@ async def trigger_stop(): app.kernel = MockKernel() app.init_sockets() async with trio.open_nursery() as nursery: - nursery.start_soon(app._start) + nursery.start_soon(lambda: app._start("trio")) nursery.start_soon(trigger_stop) app.cleanup_connection_file() app.kernel.destroy() diff --git a/tests/test_start_kernel.py b/tests/test_start_kernel.py index 71f4bdc0a..b8eaf22d9 100644 --- a/tests/test_start_kernel.py +++ b/tests/test_start_kernel.py @@ -32,7 +32,10 @@ def test_ipython_start_kernel_userns(): with setup_kernel(cmd) as client: client.inspect("custom") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert content["found"] text = content["data"]["text/plain"] @@ -44,7 +47,10 @@ def test_ipython_start_kernel_userns(): content = msg["content"] assert content["status"] == "ok" client.inspect("usermod") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert content["found"] text = content["data"]["text/plain"] @@ -68,7 +74,10 @@ def test_ipython_start_kernel_no_userns(): content = msg["content"] assert content["status"] == "ok" client.inspect("usermod") - msg = client.get_shell_msg(timeout=TIMEOUT) + while True: + msg = client.get_shell_msg(timeout=TIMEOUT) + if msg["msg_type"] == "inspect_reply": + break content = msg["content"] assert content["found"] text = content["data"]["text/plain"] diff --git a/tests/test_zmq_shell.py b/tests/test_zmq_shell.py index 8a8fe042b..33d23a59e 100644 --- a/tests/test_zmq_shell.py +++ b/tests/test_zmq_shell.py @@ -211,46 +211,53 @@ def test_unregister_hook(self): def test_magics(tmp_path): - context = zmq.Context() - socket = context.socket(zmq.PUB) - shell = InteractiveShell() - shell.user_ns["hi"] = 1 - magics = KernelMagics(shell) - - tmp_file = tmp_path / "test.txt" - tmp_file.write_text("hi", "utf8") - magics.edit(str(tmp_file)) - payload = shell.payload_manager.read_payload()[0] - assert payload["filename"] == str(tmp_file) - - magics.clear([]) - magics.less(str(tmp_file)) - if os.name == "posix": - magics.man("ls") - magics.autosave("10") - - socket.close() - context.destroy() + try: + context = zmq.Context() + socket = 
context.socket(zmq.PUB) + shell = InteractiveShell() + shell.user_ns["hi"] = 1 + magics = KernelMagics(shell) + + tmp_file = tmp_path / "test.txt" + tmp_file.write_text("hi", "utf8") + magics.edit(str(tmp_file)) + payload = shell.payload_manager.read_payload()[0] + assert payload["filename"] == str(tmp_file) + + magics.clear([]) + magics.less(str(tmp_file)) + if os.name == "posix": + magics.man("ls") + magics.autosave("10") + finally: + socket.close() + context.destroy() + shell.configurables = [] + InteractiveShell.clear_instance() def test_zmq_interactive_shell(kernel): - shell = ZMQInteractiveShell() - - with pytest.raises(RuntimeError): - shell.enable_gui("tk") - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - shell.data_pub_class = MagicMock() # type:ignore - shell.data_pub - shell.kernel = kernel - shell.set_next_input("hi") - assert shell.get_parent() is None - if os.name == "posix": - shell.system_piped("ls") - else: - shell.system_piped("dir") - shell.ask_exit() + try: + shell = ZMQInteractiveShell() + + with pytest.raises(RuntimeError): + shell.enable_gui("tk") + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + shell.data_pub_class = MagicMock() # type:ignore + shell.data_pub + shell.kernel = kernel + shell.set_next_input("hi") + assert shell.get_parent() is None + if os.name == "posix": + shell.system_piped("ls") + else: + shell.system_piped("dir") + shell.ask_exit() + finally: + shell.configurables = [] + ZMQInteractiveShell.clear_instance() if __name__ == "__main__":
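Closing note (hedged, not part of the patch): the pattern this migration standardizes on is to keep a plain synchronous zmq.Context, wrap each socket in zmq_anyio.Socket, start it (via `async with` or `tg.start(socket.start)`), and await the `a*` methods through `.wait()`. A self-contained sketch of that usage; the endpoint is illustrative:

import zmq
import zmq_anyio
from anyio import run


async def echo_once(endpoint: str = "tcp://127.0.0.1:5555") -> None:
    ctx = zmq.Context()                               # plain sync context, no zmq.asyncio
    sock = zmq_anyio.Socket(ctx.socket(zmq.ROUTER))   # wrap the sync socket
    sock.linger = 1000
    sock.bind(endpoint)
    async with sock:                                  # runs the socket for async use
        frames = await sock.arecv_multipart().wait()  # [identity, payload...]
        await sock.asend_multipart(frames).wait()     # echo back to the same peer
    ctx.destroy()


if __name__ == "__main__":
    run(echo_once)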