|
42 | 42 | import zmq_anyio |
43 | 43 | from anyio import ( |
44 | 44 | TASK_STATUS_IGNORED, |
| 45 | + CancelScope, |
45 | 46 | Event, |
46 | 47 | create_memory_object_stream, |
47 | 48 | create_task_group, |
48 | 49 | sleep, |
49 | 50 | to_thread, |
50 | 51 | ) |
51 | | -from anyio.abc import TaskStatus |
| 52 | +from anyio.abc import TaskGroup, TaskStatus |
52 | 53 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
53 | 54 | from IPython.core.error import StdinNotImplementedError |
54 | 55 | from jupyter_client.session import Session |
@@ -131,6 +132,7 @@ class Kernel(SingletonConfigurable): |
131 | 132 | _send_exec_request: Dict[dict[zmq_anyio.Socket, MemoryObjectSendStream]] = Dict() |
132 | 133 | _main_subshell_ready = Instance(Event, ()) |
133 | 134 | asyncio_event_loop = Instance(asyncio.AbstractEventLoop, allow_none=True, read_only=True) # type:ignore[call-overload] |
| 135 | + tg = Instance(TaskGroup, read_only=True) |
134 | 136 |
|
135 | 137 | log: logging.Logger = Instance(logging.Logger, allow_none=True) # type:ignore[assignment] |
136 | 138 |
|
@@ -441,23 +443,23 @@ async def shell_main(self, subshell_id: str | None): |
441 | 443 | if not socket.started.is_set(): |
442 | 444 | await tg.start(socket.start) |
443 | 445 | tg.start_soon(self._process_shell, socket) |
444 | | - tg.start_soon(self._execute_request_handler, receive_stream, subshell_id) |
| 446 | + tg.start_soon(self._execute_request_loop, receive_stream) |
445 | 447 | if subshell_id is None: |
446 | 448 | # Main subshell. |
447 | | - await to_thread.run_sync(self.shell_stop.wait) |
448 | | - tg.cancel_scope.cancel() |
| 449 | + with contextlib.suppress(RuntimeError): |
| 450 | + self.set_trait("asyncio_event_loop", asyncio.get_running_loop()) |
| 451 | + async with create_task_group() as tg_main: |
| 452 | + with CancelScope(shield=True) as scope: |
| 453 | + self.set_trait("tg", tg_main) |
| 454 | + self._main_subshell_ready.set() |
| 455 | + await to_thread.run_sync(self.shell_stop.wait) |
| 456 | + scope.cancel() |
| 457 | + tg.cancel_scope.cancel() |
449 | 458 | self._send_exec_request.pop(socket, None) |
450 | | - self.set_trait("asyncio_event_loop", None) |
451 | 459 | await send_stream.aclose() |
452 | 460 | await receive_stream.aclose() |
453 | 461 |
|
454 | | - async def _execute_request_handler( |
455 | | - self, receive_stream: MemoryObjectReceiveStream, subshell_id: str | None |
456 | | - ): |
457 | | - if subshell_id is None: |
458 | | - with contextlib.suppress(RuntimeError): |
459 | | - self.set_trait("asyncio_event_loop", asyncio.get_running_loop()) |
460 | | - self._main_subshell_ready.set() |
| 462 | + async def _execute_request_loop(self, receive_stream: MemoryObjectReceiveStream): |
461 | 463 | async with receive_stream: |
462 | 464 | async for handler, (received_time, socket, idents, msg) in receive_stream: |
463 | 465 | try: |
|
0 commit comments