From 1bec78c0fe04fd762ae35a6d38bfbc2e9fe1e9ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jiri=20Dan=C4=9Bk?= Date: Fri, 21 Mar 2025 13:48:34 +0100 Subject: [PATCH] ISSUE-922: chore(tests/containers): implement retry if port-forwarding fails --- tests/containers/kubernetes_utils.py | 8 ++-- tests/containers/socket_proxy.py | 58 ++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/tests/containers/kubernetes_utils.py b/tests/containers/kubernetes_utils.py index 4f437a312d..7fc464074b 100644 --- a/tests/containers/kubernetes_utils.py +++ b/tests/containers/kubernetes_utils.py @@ -225,7 +225,7 @@ def deploy(self, container_name: str) -> None: assert len(pod_name.items) == 1 pod: kubernetes.client.models.v1_pod.V1Pod = pod_name.items[0] - p = socket_proxy.SocketProxy(exposing_contextmanager(core_v1_api, pod), "localhost", 0) + p = socket_proxy.SocketProxy(lambda: exposing_contextmanager(core_v1_api, pod), "localhost", 0) t = threading.Thread(target=p.listen_and_serve_until_canceled) t.start() self.tf.defer(t, lambda thread: thread.join()) @@ -233,8 +233,8 @@ def deploy(self, container_name: str) -> None: self.port = p.get_actual_port() LOGGER.debug(f"Listening on port {self.port}") - resp = requests.get(f"http://localhost:{self.port}") - assert resp.status_code == 200 + Wait.until("Connecting to pod succeeds", 1, 30, + lambda: requests.get(f"http://localhost:{self.port}").status_code == 200) LOGGER.debug(f"Done with portforward") @@ -344,6 +344,8 @@ def until( result: bool = ready() except KeyboardInterrupt: raise # quick exit if the user gets tired of waiting + except SyntaxError: # this actually won't happen, but it's good to keep in mind + raise # quick exit in cases developer obviously screwed up except Exception as e: exception_message = str(e) diff --git a/tests/containers/socket_proxy.py b/tests/containers/socket_proxy.py index 77fc348579..7415864fcc 100644 --- a/tests/containers/socket_proxy.py +++ b/tests/containers/socket_proxy.py @@ -4,9 +4,10 @@ import logging import socket import select +import struct import threading import subprocess -import typing +from typing import Callable, ContextManager from tests.containers.cancellation_token import CancellationToken @@ -50,7 +51,7 @@ def stop(self): class SocketProxy: def __init__( self, - remote_socket_factory: typing.ContextManager[socket.socket], + remote_socket_factory: Callable[..., ContextManager[socket.socket]], local_host: str = "localhost", local_port: int = 0, buffer_size: int = 4096 @@ -81,9 +82,14 @@ def listen_and_serve_until_canceled(self): Handles at most one client at a time. """ try: while not self.cancellation_token.cancelled: - client_socket, addr = self.server_socket.accept() - logging.info(f"Accepted connection from {addr[0]}:{addr[1]}") - self._handle_client(client_socket) + readable, _, _ = select.select([self.server_socket, self.cancellation_token], [], []) + + # ISSUE-922: socket.accept() blocks, so if cancel() did not come very fast, we'd loop over and block + if self.server_socket in readable: + client_socket, addr = self.server_socket.accept() + logging.info(f"Accepted connection from {addr[0]}:{addr[1]}") + # handle client synchronously, which means that there can be at most one at a time + self._handle_client(client_socket) except Exception as e: logging.exception(f"Proxying failed to listen", exc_info=e) raise @@ -96,13 +102,10 @@ def get_actual_port(self) -> int: return self.server_socket.getsockname()[1] def _handle_client(self, client_socket): - with client_socket as _, self.remote_socket_factory as remote_socket: - while True: + with client_socket as _, self.remote_socket_factory() as remote_socket: + while not self.cancellation_token.cancelled: readable, _, _ = select.select([client_socket, remote_socket, self.cancellation_token], [], []) - if self.cancellation_token.cancelled: - break - if client_socket in readable: data = client_socket.recv(self.buffer_size) if not data: @@ -110,13 +113,28 @@ def _handle_client(self, client_socket): remote_socket.send(data) if remote_socket in readable: - data = remote_socket.recv(self.buffer_size) + try: + data = remote_socket.recv(self.buffer_size) + except ConnectionResetError: + # ISSUE-922: it seems best to propagate the error and let the client retry + # alternatively it would be necessary to resend anything already received from client_socket + logging.info(f"Reading from remote socket failed, client {client_socket.getpeername()} has been disconnected") + _rst_socket(client_socket) + break if not data: break client_socket.send(data) -if __name__ == "__main__": +def _rst_socket(s: socket): + """Closing a SO_LINGER socket will RST it + https://stackoverflow.com/questions/46264404/how-can-i-reset-a-tcp-socket-in-python + """ + s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0)) + s.close() + + +def main() -> None: """Sample application to show how this can work.""" @@ -161,13 +179,21 @@ def get_actual_port(self): server.join() - proxy = SocketProxy(remote_socket_factory(), "localhost", 0) + proxy = SocketProxy(remote_socket_factory, "localhost", 0) thread = threading.Thread(target=proxy.listen_and_serve_until_canceled) thread.start() - client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client_socket.connect(("localhost", proxy.get_actual_port())) + for _ in range(2): + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect(("localhost", proxy.get_actual_port())) - print(client_socket.recv(1024)) # prints Hello World + print(client_socket.recv(1024)) # prints Hello World + print(client_socket.recv(1024)) # prints nothing + client_socket.close() + proxy.cancellation_token.cancel() thread.join() + + +if __name__ == "__main__": + main()