Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions tests/containers/kubernetes_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,16 @@ def deploy(self, container_name: str) -> None:
assert len(pod_name.items) == 1
pod: kubernetes.client.models.v1_pod.V1Pod = pod_name.items[0]

p = socket_proxy.SocketProxy(exposing_contextmanager(core_v1_api, pod), "localhost", 0)
p = socket_proxy.SocketProxy(lambda: exposing_contextmanager(core_v1_api, pod), "localhost", 0)
t = threading.Thread(target=p.listen_and_serve_until_canceled)
t.start()
self.tf.defer(t, lambda thread: thread.join())
self.tf.defer(p.cancellation_token, lambda token: token.cancel())

self.port = p.get_actual_port()
LOGGER.debug(f"Listening on port {self.port}")
resp = requests.get(f"http://localhost:{self.port}")
assert resp.status_code == 200
Wait.until("Connecting to pod succeeds", 1, 30,
lambda: requests.get(f"http://localhost:{self.port}").status_code == 200)
LOGGER.debug(f"Done with portforward")


Expand Down Expand Up @@ -344,6 +344,8 @@ def until(
result: bool = ready()
except KeyboardInterrupt:
raise # quick exit if the user gets tired of waiting
except SyntaxError: # this actually won't happen, but it's good to keep in mind
raise # quick exit in cases developer obviously screwed up
except Exception as e:
exception_message = str(e)

Expand Down
58 changes: 42 additions & 16 deletions tests/containers/socket_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import logging
import socket
import select
import struct
import threading
import subprocess
import typing
from typing import Callable, ContextManager

from tests.containers.cancellation_token import CancellationToken

Expand Down Expand Up @@ -50,7 +51,7 @@ def stop(self):
class SocketProxy:
def __init__(
self,
remote_socket_factory: typing.ContextManager[socket.socket],
remote_socket_factory: Callable[..., ContextManager[socket.socket]],
local_host: str = "localhost",
local_port: int = 0,
buffer_size: int = 4096
Expand Down Expand Up @@ -81,9 +82,14 @@ def listen_and_serve_until_canceled(self):
Handles at most one client at a time. """
try:
while not self.cancellation_token.cancelled:
client_socket, addr = self.server_socket.accept()
logging.info(f"Accepted connection from {addr[0]}:{addr[1]}")
self._handle_client(client_socket)
readable, _, _ = select.select([self.server_socket, self.cancellation_token], [], [])

# ISSUE-922: socket.accept() blocks, so if cancel() did not come very fast, we'd loop over and block
if self.server_socket in readable:
client_socket, addr = self.server_socket.accept()
logging.info(f"Accepted connection from {addr[0]}:{addr[1]}")
# handle client synchronously, which means that there can be at most one at a time
self._handle_client(client_socket)
except Exception as e:
logging.exception(f"Proxying failed to listen", exc_info=e)
raise
Expand All @@ -96,27 +102,39 @@ def get_actual_port(self) -> int:
return self.server_socket.getsockname()[1]

def _handle_client(self, client_socket):
with client_socket as _, self.remote_socket_factory as remote_socket:
while True:
with client_socket as _, self.remote_socket_factory() as remote_socket:
while not self.cancellation_token.cancelled:
readable, _, _ = select.select([client_socket, remote_socket, self.cancellation_token], [], [])

if self.cancellation_token.cancelled:
break

if client_socket in readable:
data = client_socket.recv(self.buffer_size)
if not data:
break
remote_socket.send(data)

if remote_socket in readable:
data = remote_socket.recv(self.buffer_size)
try:
data = remote_socket.recv(self.buffer_size)
except ConnectionResetError:
# ISSUE-922: it seems best to propagate the error and let the client retry
# alternatively it would be necessary to resend anything already received from client_socket
logging.info(f"Reading from remote socket failed, client {client_socket.getpeername()} has been disconnected")
_rst_socket(client_socket)
break
if not data:
break
client_socket.send(data)


if __name__ == "__main__":
def _rst_socket(s: socket):
"""Closing a SO_LINGER socket will RST it
https://stackoverflow.com/questions/46264404/how-can-i-reset-a-tcp-socket-in-python
"""
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
s.close()


def main() -> None:
"""Sample application to show how this can work."""


Expand Down Expand Up @@ -161,13 +179,21 @@ def get_actual_port(self):
server.join()


proxy = SocketProxy(remote_socket_factory(), "localhost", 0)
proxy = SocketProxy(remote_socket_factory, "localhost", 0)
thread = threading.Thread(target=proxy.listen_and_serve_until_canceled)
thread.start()

client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client_socket.connect(("localhost", proxy.get_actual_port()))
for _ in range(2):
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client_socket.connect(("localhost", proxy.get_actual_port()))

print(client_socket.recv(1024)) # prints Hello World
print(client_socket.recv(1024)) # prints Hello World
print(client_socket.recv(1024)) # prints nothing
client_socket.close()
proxy.cancellation_token.cancel()

thread.join()


if __name__ == "__main__":
main()