@@ -60,8 +60,8 @@ def __init__(
6060 # Initialized later...
6161 self .comm_socket : socket | None = None
6262
63- def build_connection_file (self ) -> None :
64- ports : list = self ._select_ports (5 )
63+ def build_connection_file (self ) -> int :
64+ ports : list = self ._select_ports (6 )
6565 write_connection_file (
6666 fname = self .conn_filename ,
6767 ip = "0.0.0.0" , # noqa: S104
@@ -72,6 +72,7 @@ def build_connection_file(self) -> None:
7272 hb_port = ports [3 ],
7373 control_port = ports [4 ],
7474 )
75+ return ports [5 ]
7576
7677 def _encrypt (self , connection_info_bytes : bytes ) -> bytes :
7778 """Encrypt the connection information using a generated AES key that is then encrypted using
@@ -98,7 +99,7 @@ def _encrypt(self, connection_info_bytes: bytes) -> bytes:
9899 b64_payload = base64 .b64encode (json .dumps (payload ).encode (encoding = "utf-8" ))
99100 return b64_payload
100101
101- def return_connection_info (self ) -> None :
102+ def return_connection_info (self , comm_port ) -> None :
102103 """Returns the connection information corresponding to this kernel."""
103104 response_parts = self .response_addr .split (":" )
104105 if len (response_parts ) != 2 :
@@ -125,7 +126,7 @@ def return_connection_info(self) -> None:
125126 cf_json ["pgid" ] = os .getpgid (self .parent_pid )
126127
127128 # prepare socket address for handling signals
128- self .prepare_comm_socket () # self.comm_socket initialized
129+ self .prepare_comm_socket (comm_port ) # self.comm_socket initialized
129130 cf_json ["comm_port" ] = self .comm_socket .getsockname ()[1 ]
130131 cf_json ["kernel_id" ] = self .kernel_id
131132
@@ -137,9 +138,11 @@ def return_connection_info(self) -> None:
137138 logger .debug (f"Encrypted Payload '{ payload } " )
138139 s .send (payload )
139140
140- def prepare_comm_socket (self ) -> None :
141+ def prepare_comm_socket (self , comm_port ) -> None :
141142 """Prepares the socket to which the server will send signal and shutdown requests."""
142- self .comm_socket = self ._select_socket ()
143+ sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
144+ sock .bind (("0.0.0.0" , comm_port ))
145+ self .comm_socket = sock
143146 logger .info (
144147 f"Signal socket bound to host: "
145148 f"{ self .comm_socket .getsockname ()[0 ]} , port: { self .comm_socket .getsockname ()[1 ]} "
@@ -212,7 +215,7 @@ def get_server_request(self) -> dict:
212215
213216 return request_info
214217
215- def process_requests (self ) -> None :
218+ def process_requests (self , comm_port ) -> None :
216219 """Waits for requests from the server and processes each when received. Currently,
217220 these will be one of a sending a signal to the corresponding kernel process (signum) or
218221 stopping the listener and exiting the kernel (shutdown).
@@ -225,7 +228,7 @@ def process_requests(self) -> None:
225228
226229 # Since this creates the communication socket, we should do this here so the socket
227230 # gets created in the sub-process/thread. This is necessary on MacOS/Python.
228- self .return_connection_info ()
231+ self .return_connection_info (comm_port )
229232
230233 while not shutdown :
231234 request = self .get_server_request ()
@@ -277,10 +280,11 @@ def setup_server_listener(
277280 public_key ,
278281 cluster_type ,
279282 )
280- sl .build_connection_file ()
283+ comm_port = sl .build_connection_file ()
281284
282285 set_start_method ("fork" )
283- server_listener = Process (target = sl .process_requests )
286+ # server_listener = Process(target=sl.process_requests)
287+ server_listener = Process (target = sl .process_requests , args = (comm_port ,))
284288 server_listener .start ()
285289
286290
0 commit comments