Skip to content

Commit aaa29fe

Browse files
committed
Add anyio Event for _main_shell_ready to aid with start and test reliability.
1 parent bee82d1 commit aaa29fe

File tree

4 files changed

+22
-17
lines changed

4 files changed

+22
-17
lines changed

ipykernel/kernelapp.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import zmq
2222
import zmq_anyio
23-
from anyio import create_task_group, run, to_thread
23+
from anyio import create_task_group, run
2424
from IPython.core.application import ( # type:ignore[attr-defined]
2525
BaseIPythonApplication,
2626
base_aliases,
@@ -764,15 +764,9 @@ def start(self) -> None:
764764
backend = "trio" if self.trio_loop else "asyncio"
765765
run(partial(self._start, backend), backend=backend)
766766

767-
async def _wait_to_enter_eventloop(self) -> None:
768-
await to_thread.run_sync(self.kernel._eventloop_set.wait)
769-
await self.kernel.enter_eventloop()
770-
771767
async def main(self) -> None:
772768
async with create_task_group() as tg:
773-
tg.start_soon(self._wait_to_enter_eventloop)
774769
tg.start_soon(self.kernel.start)
775-
776770
if self.kernel.eventloop:
777771
self.kernel._eventloop_set.set()
778772

ipykernel/kernelbase.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import zmq_anyio
4141
from anyio import (
4242
TASK_STATUS_IGNORED,
43+
Event,
4344
create_memory_object_stream,
4445
create_task_group,
4546
sleep,
@@ -126,6 +127,7 @@ class Kernel(SingletonConfigurable):
126127
stdin_socket = Any()
127128

128129
_send_exec_request: Dict[dict[zmq_anyio.Socket, MemoryObjectSendStream]] = Dict()
130+
_main_shell_ready = Instance(Event, ())
129131

130132
log: logging.Logger = Instance(logging.Logger, allow_none=True) # type:ignore[assignment]
131133

@@ -436,13 +438,15 @@ async def shell_main(self, subshell_id: str | None):
436438
async with create_task_group() as tg:
437439
if not socket.started.is_set():
438440
await tg.start(socket.start)
439-
tg.start_soon(self.process_shell, socket)
441+
tg.start_soon(self._process_shell, socket)
440442
tg.start_soon(self._execute_request_handler, receive_stream)
441443
if subshell_id is None:
442444
# Main subshell.
445+
self._main_shell_ready.set()
443446
await to_thread.run_sync(self.shell_stop.wait)
444447
tg.cancel_scope.cancel()
445448
self._send_exec_request.pop(socket, None)
449+
await send_stream.aclose()
446450

447451
async def _execute_request_handler(self, receive_stream: MemoryObjectReceiveStream):
448452
async with receive_stream:
@@ -461,8 +465,9 @@ async def _execute_request_handler(self, receive_stream: MemoryObjectReceiveStre
461465
except BaseException as e:
462466
self.log.exception("Execute request", exc_info=e)
463467

464-
async def process_shell(self, socket):
468+
async def _process_shell(self, socket):
465469
# socket=None is valid if kernel subshells are not supported.
470+
await self._main_shell_ready.wait()
466471
try:
467472
while True:
468473
await self.process_shell_message(socket=socket)
@@ -476,15 +481,13 @@ async def process_shell_message(self, msg=None, socket=None):
476481
# If msg is set, process that message.
477482
# If msg is None, await the next message to arrive on the socket.
478483
assert self.session is not None
484+
socket = socket or self.shell_socket
479485
if self._supports_kernel_subshells:
480486
assert threading.current_thread() not in (
481487
self.control_thread,
482488
self.shell_channel_thread,
483489
)
484490
assert socket is not None
485-
else:
486-
assert threading.current_thread() == threading.main_thread()
487-
socket = self.shell_socket
488491

489492
msg = msg or await socket.arecv_multipart(copy=False).wait()
490493

@@ -532,8 +535,8 @@ async def process_shell_message(self, msg=None, socket=None):
532535
result = handler(socket, idents, msg)
533536
if inspect.isawaitable(result):
534537
await result
535-
except Exception:
536-
self.log.error("Exception in message handler:", exc_info=True) # noqa: G201
538+
except Exception as e:
539+
self.log.error("Exception in message handler:", exc_info=e)
537540
except KeyboardInterrupt:
538541
# Ctrl-c shouldn't crash the kernel here.
539542
self.log.error("KeyboardInterrupt caught in kernel.")
@@ -583,6 +586,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
583586
self.shell_stop = threading.Event()
584587

585588
tg.start_soon(self.shell_main, None)
589+
await self._main_shell_ready.wait()
586590
if self.shell_channel_thread:
587591
# Assign tasks to and start shell channel thread.
588592
manager = self.shell_channel_thread.manager

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ async def kernel(anyio_backend):
207207
async with create_task_group() as tg:
208208
kernel = MockKernel()
209209
tg.start_soon(kernel.start)
210+
await kernel._main_shell_ready.wait()
210211
try:
211212
yield kernel
212213
finally:
@@ -218,6 +219,7 @@ async def ipkernel(anyio_backend):
218219
async with create_task_group() as tg:
219220
kernel = MockIPyKernel()
220221
tg.start_soon(kernel.start)
222+
await kernel._main_shell_ready.wait()
221223
try:
222224
yield kernel
223225
finally:

tests/test_ipkernel_direct.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,19 @@ async def test_direct_kernel_info_request(ipkernel):
3434

3535

3636
async def test_direct_execute_request(ipkernel: MockIPyKernel) -> None:
37-
reply = await ipkernel.test_shell_message("execute_request", dict(code="hello", silent=False))
37+
reply = await ipkernel.test_shell_message(
38+
"execute_request", dict(code="invalid_call()", silent=False)
39+
)
3840
assert reply["header"]["msg_type"] == "execute_reply"
41+
ipkernel._aborted_time += 10
3942
reply = await ipkernel.test_shell_message(
4043
"execute_request", dict(code="trigger_error", silent=False)
4144
)
4245
assert reply["content"]["status"] == "aborted"
43-
44-
reply = await ipkernel.test_shell_message("execute_request", dict(code="hello", silent=False))
46+
ipkernel._aborted_time = time.monotonic()
47+
reply = await ipkernel.test_shell_message(
48+
"execute_request", dict(code="okay=True", silent=False)
49+
)
4550
assert reply["header"]["msg_type"] == "execute_reply"
4651

4752

0 commit comments

Comments
 (0)