@@ -118,6 +118,11 @@ async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket
118118 await thread .task_group .start (self ._control_other_socket .start )
119119 return self ._control_other_socket
120120
121+ async def get_control_shell_channel_socket (self , thread : BaseThread ) -> zmq_anyio .Socket :
122+ if not self ._control_shell_channel_socket .started .is_set ():
123+ await thread .task_group .start (self ._control_shell_channel_socket .start )
124+ return self ._control_shell_channel_socket
125+
121126 def get_other_socket (self , subshell_id : str | None ) -> zmq_anyio .Socket :
122127 """Return the other inproc pair socket for a subshell.
123128
@@ -148,18 +153,17 @@ def list_subshell(self) -> list[str]:
148153 with self ._lock_cache :
149154 return list (self ._cache )
150155
151- async def listen_from_control (self , subshell_task : t .Any ) -> None :
156+ async def listen_from_control (self , subshell_task : t .Any , thread : BaseThread ) -> None :
152157 """Listen for messages on the control inproc socket, handle those messages and
153158 return replies on the same socket. Runs in the shell channel thread.
154159 """
155160 assert current_thread ().name == SHELL_CHANNEL_THREAD_NAME
156161
157- socket = self ._control_shell_channel_socket
158- async with socket :
159- while True :
160- request = await socket .arecv_json ().wait ()
161- reply = await self ._process_control_request (request , subshell_task )
162- await socket .asend_json (reply ).wait ()
162+ socket = await self .get_control_shell_channel_socket (thread )
163+ while True :
164+ request = await socket .arecv_json ().wait ()
165+ reply = await self ._process_control_request (request , subshell_task )
166+ await socket .asend_json (reply ).wait ()
163167
164168 async def listen_from_subshells (self ) -> None :
165169 """Listen for reply messages on inproc sockets of all subshells and resend
@@ -265,8 +269,8 @@ async def _listen_for_subshell_reply(
265269
266270 shell_channel_socket = self ._get_shell_channel_socket (subshell_id )
267271
268- task_group . start_soon ( shell_channel_socket .start )
269- await shell_channel_socket . started . wait ( )
272+ if not shell_channel_socket .started . is_set ():
273+ await task_group . start ( shell_channel_socket . start )
270274 try :
271275 while True :
272276 msg = await shell_channel_socket .arecv_multipart (copy = False ).wait ()
0 commit comments