4040import zmq_anyio
4141from anyio import (
4242 TASK_STATUS_IGNORED ,
43+ Event ,
4344 create_memory_object_stream ,
4445 create_task_group ,
4546 sleep ,
@@ -126,6 +127,7 @@ class Kernel(SingletonConfigurable):
126127 stdin_socket = Any ()
127128
128129 _send_exec_request : Dict [dict [zmq_anyio .Socket , MemoryObjectSendStream ]] = Dict ()
130+ _main_shell_ready = Instance (Event , ())
129131
130132 log : logging .Logger = Instance (logging .Logger , allow_none = True ) # type:ignore[assignment]
131133
@@ -436,13 +438,15 @@ async def shell_main(self, subshell_id: str | None):
436438 async with create_task_group () as tg :
437439 if not socket .started .is_set ():
438440 await tg .start (socket .start )
439- tg .start_soon (self .process_shell , socket )
441+ tg .start_soon (self ._process_shell , socket )
440442 tg .start_soon (self ._execute_request_handler , receive_stream )
441443 if subshell_id is None :
442444 # Main subshell.
445+ self ._main_shell_ready .set ()
443446 await to_thread .run_sync (self .shell_stop .wait )
444447 tg .cancel_scope .cancel ()
445448 self ._send_exec_request .pop (socket , None )
449+ await send_stream .aclose ()
446450
447451 async def _execute_request_handler (self , receive_stream : MemoryObjectReceiveStream ):
448452 async with receive_stream :
@@ -461,8 +465,9 @@ async def _execute_request_handler(self, receive_stream: MemoryObjectReceiveStre
461465 except BaseException as e :
462466 self .log .exception ("Execute request" , exc_info = e )
463467
464- async def process_shell (self , socket ):
468+ async def _process_shell (self , socket ):
465469 # socket=None is valid if kernel subshells are not supported.
470+ await self ._main_shell_ready .wait ()
466471 try :
467472 while True :
468473 await self .process_shell_message (socket = socket )
@@ -476,15 +481,13 @@ async def process_shell_message(self, msg=None, socket=None):
476481 # If msg is set, process that message.
477482 # If msg is None, await the next message to arrive on the socket.
478483 assert self .session is not None
484+ socket = socket or self .shell_socket
479485 if self ._supports_kernel_subshells :
480486 assert threading .current_thread () not in (
481487 self .control_thread ,
482488 self .shell_channel_thread ,
483489 )
484490 assert socket is not None
485- else :
486- assert threading .current_thread () == threading .main_thread ()
487- socket = self .shell_socket
488491
489492 msg = msg or await socket .arecv_multipart (copy = False ).wait ()
490493
@@ -532,8 +535,8 @@ async def process_shell_message(self, msg=None, socket=None):
532535 result = handler (socket , idents , msg )
533536 if inspect .isawaitable (result ):
534537 await result
535- except Exception :
536- self .log .error ("Exception in message handler:" , exc_info = True ) # noqa: G201
538+ except Exception as e :
539+ self .log .error ("Exception in message handler:" , exc_info = e )
537540 except KeyboardInterrupt :
538541 # Ctrl-c shouldn't crash the kernel here.
539542 self .log .error ("KeyboardInterrupt caught in kernel." )
@@ -583,6 +586,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
583586 self .shell_stop = threading .Event ()
584587
585588 tg .start_soon (self .shell_main , None )
589+ await self ._main_shell_ready .wait ()
586590 if self .shell_channel_thread :
587591 # Assign tasks to and start shell channel thread.
588592 manager = self .shell_channel_thread .manager
0 commit comments