@@ -60,8 +60,8 @@ def __init__(
60
60
# Initialized later...
61
61
self .comm_socket : socket | None = None
62
62
63
- def build_connection_file (self ) -> None :
64
- ports : list = self ._select_ports (5 )
63
+ def build_connection_file (self ) -> int :
64
+ ports : list = self ._select_ports (6 )
65
65
write_connection_file (
66
66
fname = self .conn_filename ,
67
67
ip = "0.0.0.0" , # noqa: S104
@@ -72,6 +72,7 @@ def build_connection_file(self) -> None:
72
72
hb_port = ports [3 ],
73
73
control_port = ports [4 ],
74
74
)
75
+ return ports [5 ]
75
76
76
77
def _encrypt (self , connection_info_bytes : bytes ) -> bytes :
77
78
"""Encrypt the connection information using a generated AES key that is then encrypted using
@@ -98,7 +99,7 @@ def _encrypt(self, connection_info_bytes: bytes) -> bytes:
98
99
b64_payload = base64 .b64encode (json .dumps (payload ).encode (encoding = "utf-8" ))
99
100
return b64_payload
100
101
101
- def return_connection_info (self ) -> None :
102
+ def return_connection_info (self , comm_port ) -> None :
102
103
"""Returns the connection information corresponding to this kernel."""
103
104
response_parts = self .response_addr .split (":" )
104
105
if len (response_parts ) != 2 :
@@ -125,7 +126,7 @@ def return_connection_info(self) -> None:
125
126
cf_json ["pgid" ] = os .getpgid (self .parent_pid )
126
127
127
128
# prepare socket address for handling signals
128
- self .prepare_comm_socket () # self.comm_socket initialized
129
+ self .prepare_comm_socket (comm_port ) # self.comm_socket initialized
129
130
cf_json ["comm_port" ] = self .comm_socket .getsockname ()[1 ]
130
131
cf_json ["kernel_id" ] = self .kernel_id
131
132
@@ -137,9 +138,11 @@ def return_connection_info(self) -> None:
137
138
logger .debug (f"Encrypted Payload '{ payload } " )
138
139
s .send (payload )
139
140
140
- def prepare_comm_socket (self ) -> None :
141
+ def prepare_comm_socket (self , comm_port ) -> None :
141
142
"""Prepares the socket to which the server will send signal and shutdown requests."""
142
- self .comm_socket = self ._select_socket ()
143
+ sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
144
+ sock .bind (("0.0.0.0" , comm_port ))
145
+ self .comm_socket = sock
143
146
logger .info (
144
147
f"Signal socket bound to host: "
145
148
f"{ self .comm_socket .getsockname ()[0 ]} , port: { self .comm_socket .getsockname ()[1 ]} "
@@ -212,7 +215,7 @@ def get_server_request(self) -> dict:
212
215
213
216
return request_info
214
217
215
- def process_requests (self ) -> None :
218
+ def process_requests (self , comm_port ) -> None :
216
219
"""Waits for requests from the server and processes each when received. Currently,
217
220
these will be one of a sending a signal to the corresponding kernel process (signum) or
218
221
stopping the listener and exiting the kernel (shutdown).
@@ -225,7 +228,7 @@ def process_requests(self) -> None:
225
228
226
229
# Since this creates the communication socket, we should do this here so the socket
227
230
# gets created in the sub-process/thread. This is necessary on MacOS/Python.
228
- self .return_connection_info ()
231
+ self .return_connection_info (comm_port )
229
232
230
233
while not shutdown :
231
234
request = self .get_server_request ()
@@ -277,10 +280,11 @@ def setup_server_listener(
277
280
public_key ,
278
281
cluster_type ,
279
282
)
280
- sl .build_connection_file ()
283
+ comm_port = sl .build_connection_file ()
281
284
282
285
set_start_method ("fork" )
283
- server_listener = Process (target = sl .process_requests )
286
+ # server_listener = Process(target=sl.process_requests)
287
+ server_listener = Process (target = sl .process_requests , args = (comm_port ,))
284
288
server_listener .start ()
285
289
286
290
0 commit comments