Skip to content

Commit 1eb6af2

Browse files
author
Uziel Silva
committed
feat: Add proxy server and fix all unit tests
1 parent 7362233 commit 1eb6af2

File tree

9 files changed

+1062
-355
lines changed

9 files changed

+1062
-355
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ dist/
88
.idea
99
.coverage
1010
sponge_log.xml
11+
*.iml

google/cloud/sql/connector/connector.py

Lines changed: 167 additions & 103 deletions
Large diffs are not rendered by default.

google/cloud/sql/connector/enums.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class DriverMapping(Enum):
6262

6363
ASYNCPG = "POSTGRES"
6464
PG8000 = "POSTGRES"
65-
LOCAL_UNIX_SOCKET = "ANY"
65+
PSYCOPG = "POSTGRES"
6666
PYMYSQL = "MYSQL"
6767
PYTDS = "SQLSERVER"
6868

@@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None:
7979
the given engine.
8080
"""
8181
mapping = DriverMapping[driver.upper()]
82-
if not mapping.value == "ANY" and not engine_version.startswith(mapping.value):
82+
if not engine_version.startswith(mapping.value):
8383
raise IncompatibleDriverError(
8484
f"Database driver '{driver}' is incompatible with database "
8585
f"version '{engine_version}'. Given driver can "

google/cloud/sql/connector/proxy.py

Lines changed: 217 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -14,127 +14,253 @@
1414
limitations under the License.
1515
"""
1616

17+
from __future__ import annotations
18+
19+
from abc import ABC
20+
from abc import abstractmethod
1721
import asyncio
22+
from functools import partial
23+
import logging
1824
import os
1925
from pathlib import Path
20-
import socket
21-
import selectors
22-
import ssl
26+
from typing import Callable, List
2327

24-
from google.cloud.sql.connector.exceptions import LocalProxyStartupError
28+
logger = logging.getLogger(name=__name__)
2529

26-
LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760
2730

31+
class BaseProxyProtocol(asyncio.Protocol):
32+
"""
33+
A protocol to proxy data between two transports.
34+
"""
2835

29-
class Proxy:
30-
"""Creates an "accept loop" async task which will open the unix server socket and listen for new connections."""
36+
def __init__(self, proxy: Proxy):
37+
super().__init__()
38+
self.proxy = proxy
39+
self._buffer = bytearray()
40+
self._target: asyncio.Transport | None = None
41+
self.transport: asyncio.Transport | None = None
42+
self._cached: List[bytes] = []
43+
logger.debug(f"__init__ {self}")
44+
45+
def connection_made(self, transport):
46+
logger.debug(f"connection_made {self}")
47+
self.transport = transport
48+
49+
def data_received(self, data):
50+
if self._target is None:
51+
self._cached.append(data)
52+
else:
53+
self._target.write(data)
54+
55+
def set_target(self, target: asyncio.Transport):
56+
logger.debug(f"set_target {self}")
57+
self._target = target
58+
if self._cached:
59+
self._target.writelines(self._cached)
60+
self._cached = []
61+
62+
def eof_received(self):
63+
logger.debug(f"eof_received {self}")
64+
if self._target is not None:
65+
self._target.write_eof()
66+
67+
def connection_lost(self, exc: Exception | None):
68+
logger.debug(f"connection_lost {exc} {self}")
69+
if self._target is not None:
70+
self._target.close()
71+
72+
73+
class ProxyClientConnection:
74+
"""
75+
Holds all of the tasks and details for a client proxy
76+
"""
3177

3278
def __init__(
3379
self,
34-
connector,
35-
instance_connection_string: str,
36-
socket_path: str,
37-
loop: asyncio.AbstractEventLoop,
38-
**kwargs
39-
) -> None:
40-
"""Keeps track of all the async tasks and starts the accept loop for new connections.
41-
42-
Args:
43-
connector (Connector): The instance where this Proxy class was created.
80+
client_transport: asyncio.Transport,
81+
client_protocol: ClientToServerProtocol,
82+
):
83+
self.client_transport = client_transport
84+
self.client_protocol = client_protocol
85+
self.server_transport: asyncio.Transport | None = None
86+
self.server_protocol: ServerToClientProtocol | None = None
87+
self.task: asyncio.Task | None = None
88+
89+
def close(self):
90+
logger.debug(f"closing {self}")
91+
if self.client_transport is not None:
92+
self._close_transport(self.client_transport)
93+
if self.server_transport is not None:
94+
self._close_transport(self.server_transport)
95+
96+
def _close_transport(self, transport:asyncio.Transport):
97+
if transport.is_closing():
98+
return
99+
if transport.can_write_eof():
100+
transport.write_eof()
101+
else:
102+
transport.close()
103+
104+
class ClientToServerProtocol(BaseProxyProtocol):
105+
"""
106+
Protocol to copy bytes from the unix socket client to the database server
107+
"""
108+
109+
def __init__(self, proxy: Proxy):
110+
super().__init__(proxy)
111+
self._buffer = bytearray()
112+
self._target: asyncio.Transport | None = None
113+
logger.debug(f"__init__ {self}")
114+
115+
def connection_made(self, transport):
116+
# When a connection is made, open the server connection
117+
super().connection_made(transport)
118+
self.proxy._handle_client_connection(transport, self)
44119

45-
instance_connection_string (str): The instance connection name of the
46-
Cloud SQL instance to connect to. Takes the form of
47-
"project-id:region:instance-name"
48120

49-
Example: "my-project:us-central1:my-instance"
121+
class ServerToClientProtocol(BaseProxyProtocol):
122+
"""
123+
Protocol to copy bytes from the database server to the client socket
124+
"""
50125

51-
socket_path (str): A system path that is going to be used to store the socket.
126+
def __init__(self, proxy: Proxy, cconn: ProxyClientConnection):
127+
super().__init__(proxy)
128+
self._buffer = bytearray()
129+
self._target = cconn.client_transport
130+
self._client_protocol = cconn.client_protocol
131+
logger.debug(f"__init__ {self}")
52132

53-
loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks.
133+
def connection_made(self, transport):
134+
super().connection_made(transport)
135+
self._client_protocol.set_target(transport)
54136

55-
**kwargs: Any driver-specific arguments to pass to the underlying
56-
driver .connect call.
137+
def connection_lost(self, exc: Exception | None):
138+
super().connection_lost(exc)
139+
self.proxy._handle_server_connection_lost()
140+
141+
class ServerConnectionFactory(ABC):
142+
"""
143+
ServerConnectionFactory is an abstract class that provides connections to the service.
144+
"""
145+
@abstractmethod
146+
async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]):
147+
"""
148+
Establishes a connection to the server and configures it to use the protocol
149+
returned from protocol_fn, with asyncio.EventLoop.create_connection().
150+
:param protocol_fn: the protocol function
151+
:return: None
57152
"""
58-
self._connection_tasks = []
59-
self._addr = instance_connection_string
60-
self._kwargs = kwargs
61-
self._connector = connector
153+
pass
62154

63-
unix_socket = None
155+
class Proxy:
156+
"""
157+
A class to represent a local Unix socket proxy for a Cloud SQL instance.
158+
This class manages a Unix socket that listens for incoming connections and
159+
proxies them to a Cloud SQL instance.
160+
"""
64161

65-
try:
66-
path_parts = socket_path.rsplit('/', 1)
67-
parent_directory = '/'.join(path_parts[:-1])
162+
def __init__(
163+
self,
164+
unix_socket_path: str,
165+
server_connection_factory: ServerConnectionFactory,
166+
loop: asyncio.AbstractEventLoop,
167+
):
168+
"""
169+
Creates a new Proxy
170+
:param unix_socket_path: the path to listen for the proxy connection
171+
:param loop: The event loop
172+
:param instance_connect: A function that will establish the async connection to the server
173+
174+
The instance_connect function is an asynchronous function that should set up a new connection.
175+
It takes one argument - another function that
176+
"""
177+
self.unix_socket_path = unix_socket_path
178+
self.alive = True
179+
self._loop = loop
180+
self._server: asyncio.AbstractServer | None = None
181+
self._client_connections: set[ProxyClientConnection] = set()
182+
self._server_connection_factory = server_connection_factory
68183

69-
desired_path = Path(parent_directory)
70-
desired_path.mkdir(parents=True, exist_ok=True)
184+
async def start(self) -> None:
185+
"""Starts the Unix socket server."""
186+
if os.path.exists(self.unix_socket_path):
187+
os.remove(self.unix_socket_path)
71188

72-
if os.path.exists(socket_path):
73-
os.remove(socket_path)
189+
parent_dir = Path(self.unix_socket_path).parent
190+
parent_dir.mkdir(parents=True, exist_ok=True)
74191

75-
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
192+
def new_protocol() -> ClientToServerProtocol:
193+
return ClientToServerProtocol(self)
76194

77-
unix_socket.bind(socket_path)
78-
unix_socket.listen(1)
79-
unix_socket.setblocking(False)
80-
os.chmod(socket_path, 0o600)
195+
logger.debug(f"Socket path: {self.unix_socket_path}")
196+
self._server = await self._loop.create_unix_server(
197+
new_protocol, path=self.unix_socket_path
198+
)
199+
self._loop.create_task(self._server.serve_forever())
81200

82-
self._task = loop.create_task(self.accept_loop(unix_socket, socket_path, loop))
201+
def _handle_client_connection(
202+
self,
203+
client_transport: asyncio.Transport,
204+
client_protocol: ClientToServerProtocol,
205+
) -> None:
206+
"""
207+
Register a new client connection and initiate the task to create a database connection.
208+
This is called by ClientToServerProtocol.connection_made
83209
84-
except Exception:
85-
raise LocalProxyStartupError(
86-
'Local UNIX socket based proxy was not able to get started.'
87-
)
210+
:param client_transport: the client transport for the client unix socket
211+
:param client_protocol: the instance for the
212+
:return: None
213+
"""
214+
conn = ProxyClientConnection(client_transport, client_protocol)
215+
self._client_connections.add(conn)
216+
conn.task = self._loop.create_task(self._create_db_instance_connection(conn))
217+
conn.task.add_done_callback(lambda _: self._client_connections.discard(conn))
88218

89-
async def accept_loop(
219+
def _handle_server_connection_lost(
90220
self,
91-
unix_socket,
92-
socket_path: str,
93-
loop: asyncio.AbstractEventLoop
94-
) -> asyncio.Task:
95-
"""Starts a UNIX based local proxy for transporting messages through
96-
the SSL Socket, and waits until there is a new connection to accept, to register it
97-
and keep track of it.
221+
) -> None:
222+
"""
223+
Closes the proxy server if the connection to the server is lost
98224
99-
Args:
100-
socket_path: A system path that is going to be used to store the socket.
225+
:return: None
226+
"""
227+
logger.debug(f"Closing proxy server due to lost connection")
228+
self._loop.create_task(self.close())
229+
230+
async def _create_db_instance_connection(self, conn: ProxyClientConnection) -> None:
231+
"""
232+
Manages a single proxy connection from a client to the Cloud SQL instance.
233+
"""
234+
try:
235+
logger.debug("_proxy_connection() started")
236+
new_protocol = partial(ServerToClientProtocol, self, conn)
237+
238+
# Establish connection to the database
239+
await self._server_connection_factory.connect(new_protocol)
240+
logger.debug("_proxy_connection() succeeded")
101241

102-
loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks.
242+
except Exception as e:
243+
logger.error(f"Error handling proxy connection: {e}")
244+
await self.close()
245+
raise e
103246

104-
Raises:
105-
LocalProxyStartupError: Local UNIX socket based proxy was not able to
106-
get started.
247+
async def close(self) -> None:
107248
"""
108-
print("on accept loop")
109-
while True:
110-
client, _ = await loop.sock_accept(unix_socket)
111-
self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop)))
249+
Shuts down the proxy server and cleans up resources.
250+
"""
251+
logger.info(f"Closing Unix socket proxy at {self.unix_socket_path}")
112252

113-
async def close_async(self):
114-
proxy_task = asyncio.gather(self._task)
115-
try:
116-
await asyncio.wait_for(proxy_task, timeout=0.1)
117-
except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError):
118-
pass # This task runs forever so it is expected to throw this exception
253+
if self._server:
254+
self._server.close()
255+
await self._server.wait_closed()
119256

257+
if self._client_connections:
258+
for conn in list(self._client_connections):
259+
conn.close()
260+
await asyncio.wait([c.task for c in self._client_connections if c.task is not None], timeout=0.1)
120261

121-
async def client_socket(
122-
self, client, unix_socket, socket_path, loop
123-
):
124-
try:
125-
ssl_sock = self._connector.connect(
126-
self._addr,
127-
'local_unix_socket',
128-
**self._kwargs
129-
)
130-
while True:
131-
data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE)
132-
if not data:
133-
client.close()
134-
break
135-
ssl_sock.sendall(data)
136-
response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE)
137-
await loop.sock_sendall(client, response)
138-
finally:
139-
client.close()
140-
os.remove(socket_path) # Clean up the socket file
262+
if os.path.exists(self.unix_socket_path):
263+
os.remove(self.unix_socket_path)
264+
265+
logger.info(f"Unix socket proxy for {self.unix_socket_path} closed.")
266+
self.alive = False

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ exclude = ['docs/*', 'samples/*']
8282

8383
[tool.pytest.ini_options]
8484
asyncio_mode = "auto"
85+
log_cli = true
86+
log_cli_level = "DEBUG"
87+
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
88+
log_cli_date_format = "%Y-%m-%d %H:%M:%S.%f"
8589

8690
[tool.ruff.lint]
8791
extend-select = ["I"]

0 commit comments

Comments
 (0)