Skip to content

Commit 700090a

Browse files
hjliu0206huajie.liu
andauthored
fix communication port duplicate issue (#134)
Co-authored-by: huajie.liu <[email protected]>
1 parent aaf4835 commit 700090a

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

gateway_provisioners/kernel-launchers/shared/scripts/server_listener.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def __init__(
6060
# Initialized later...
6161
self.comm_socket: socket | None = None
6262

63-
def build_connection_file(self) -> None:
64-
ports: list = self._select_ports(5)
63+
def build_connection_file(self) -> int:
64+
ports: list = self._select_ports(6)
6565
write_connection_file(
6666
fname=self.conn_filename,
6767
ip="0.0.0.0", # noqa: S104
@@ -72,6 +72,7 @@ def build_connection_file(self) -> None:
7272
hb_port=ports[3],
7373
control_port=ports[4],
7474
)
75+
return ports[5]
7576

7677
def _encrypt(self, connection_info_bytes: bytes) -> bytes:
7778
"""Encrypt the connection information using a generated AES key that is then encrypted using
@@ -98,7 +99,7 @@ def _encrypt(self, connection_info_bytes: bytes) -> bytes:
9899
b64_payload = base64.b64encode(json.dumps(payload).encode(encoding="utf-8"))
99100
return b64_payload
100101

101-
def return_connection_info(self) -> None:
102+
def return_connection_info(self, comm_port) -> None:
102103
"""Returns the connection information corresponding to this kernel."""
103104
response_parts = self.response_addr.split(":")
104105
if len(response_parts) != 2:
@@ -125,7 +126,7 @@ def return_connection_info(self) -> None:
125126
cf_json["pgid"] = os.getpgid(self.parent_pid)
126127

127128
# prepare socket address for handling signals
128-
self.prepare_comm_socket() # self.comm_socket initialized
129+
self.prepare_comm_socket(comm_port) # self.comm_socket initialized
129130
cf_json["comm_port"] = self.comm_socket.getsockname()[1]
130131
cf_json["kernel_id"] = self.kernel_id
131132

@@ -137,9 +138,11 @@ def return_connection_info(self) -> None:
137138
logger.debug(f"Encrypted Payload '{payload}")
138139
s.send(payload)
139140

140-
def prepare_comm_socket(self) -> None:
141+
def prepare_comm_socket(self, comm_port) -> None:
141142
"""Prepares the socket to which the server will send signal and shutdown requests."""
142-
self.comm_socket = self._select_socket()
143+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
144+
sock.bind(("0.0.0.0", comm_port))
145+
self.comm_socket = sock
143146
logger.info(
144147
f"Signal socket bound to host: "
145148
f"{self.comm_socket.getsockname()[0]}, port: {self.comm_socket.getsockname()[1]}"
@@ -212,7 +215,7 @@ def get_server_request(self) -> dict:
212215

213216
return request_info
214217

215-
def process_requests(self) -> None:
218+
def process_requests(self, comm_port) -> None:
216219
"""Waits for requests from the server and processes each when received. Currently,
217220
these will be one of a sending a signal to the corresponding kernel process (signum) or
218221
stopping the listener and exiting the kernel (shutdown).
@@ -225,7 +228,7 @@ def process_requests(self) -> None:
225228

226229
# Since this creates the communication socket, we should do this here so the socket
227230
# gets created in the sub-process/thread. This is necessary on MacOS/Python.
228-
self.return_connection_info()
231+
self.return_connection_info(comm_port)
229232

230233
while not shutdown:
231234
request = self.get_server_request()
@@ -277,10 +280,11 @@ def setup_server_listener(
277280
public_key,
278281
cluster_type,
279282
)
280-
sl.build_connection_file()
283+
comm_port = sl.build_connection_file()
281284

282285
set_start_method("fork")
283-
server_listener = Process(target=sl.process_requests)
286+
# server_listener = Process(target=sl.process_requests)
287+
server_listener = Process(target=sl.process_requests, args=(comm_port,))
284288
server_listener.start()
285289

286290

0 commit comments

Comments
 (0)