Skip to content

Commit 28c1c40

Browse files
author
Uziel Silva
committed
fix(main): fix feedback PR
Changelog: - Make local_socket_path configurable - Set right file permissions - Handle exceptions properly - Use asyncio and the main loop to stop the local proxy and clear the file when the connector is stopped
1 parent 89719a4 commit 28c1c40

File tree

5 files changed

+166
-69
lines changed

5 files changed

+166
-69
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@
4444
from google.cloud.sql.connector.resolver import DnsResolver
4545
from google.cloud.sql.connector.utils import format_database_user
4646
from google.cloud.sql.connector.utils import generate_keys
47+
from google.cloud.sql.connector.proxy import start_local_proxy
4748

4849
logger = logging.getLogger(name=__name__)
4950

5051
ASYNC_DRIVERS = ["asyncpg"]
52+
LOCAL_PROXY_DRIVERS = ["psycopg"]
5153
SERVER_PROXY_PORT = 3307
5254
_DEFAULT_SCHEME = "https://"
5355
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
@@ -383,7 +385,7 @@ async def connect_async(
383385
# async drivers are unblocking and can be awaited directly
384386
if driver in ASYNC_DRIVERS:
385387
return await connector(
386-
ip_address,
388+
host,
387389
await conn_info.create_ssl_context(enable_iam_auth),
388390
**kwargs,
389391
)
@@ -393,14 +395,26 @@ async def connect_async(
393395
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
394396
server_hostname=ip_address,
395397
)
398+
399+
host = ip_address
400+
# start local proxy if driver needs it
401+
if driver in LOCAL_PROXY_DRIVERS:
402+
local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket")
403+
host = local_socket_path
404+
start_local_proxy(
405+
sock,
406+
socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}",
407+
loop=self._loop
408+
)
409+
396410
# If this connection was opened using a domain name, then store it
397411
# for later in case we need to forcibly close it on failover.
398412
if conn_info.conn_name.domain_name:
399413
monitored_cache.sockets.append(sock)
400414
# Synchronous drivers are blocking and run using executor
401415
connect_partial = partial(
402416
connector,
403-
ip_address,
417+
host,
404418
sock,
405419
**kwargs,
406420
)

google/cloud/sql/connector/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,9 @@ class CacheClosedError(Exception):
8484
Exception to be raised when a ConnectionInfoCache can not be accessed after
8585
it is closed.
8686
"""
87+
88+
89+
class LocalProxyStartupError(Exception):
90+
"""
91+
Exception to be raised when a the local UNIX-socket based proxy can not be started.
92+
"""

google/cloud/sql/connector/proxy.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,48 +16,75 @@
1616

1717
import socket
1818
import os
19-
import threading
19+
import ssl
20+
import asyncio
2021
from pathlib import Path
22+
from typing import Optional
23+
24+
from google.cloud.sql.connector.exceptions import LocalProxyStartupError
2125

2226
SERVER_PROXY_PORT = 3307
27+
LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760
2328

2429
def start_local_proxy(
25-
ssl_sock,
26-
socket_path,
30+
ssl_sock: ssl.SSLSocket,
31+
socket_path: Optional[str] = "/tmp/connector-socket",
32+
loop: Optional[asyncio.AbstractEventLoop] = None,
2733
):
28-
path_parts = socket_path.rsplit('/', 1)
29-
parent_directory = '/'.join(path_parts[:-1])
34+
"""Helper function to start a UNIX based local proxy for
35+
transport messages through the SSL Socket.
36+
37+
Args:
38+
ssl_sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL
39+
server CA cert and ephemeral cert.
40+
socket_path: A system path that is going to be used to store the socket.
41+
loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks.
42+
43+
Raises:
44+
LocalProxyStartupError: Local UNIX socket based proxy was not able to
45+
get started.
46+
"""
47+
unix_socket = None
3048

31-
desired_path = Path(parent_directory)
32-
desired_path.mkdir(parents=True, exist_ok=True)
49+
try:
50+
path_parts = socket_path.rsplit('/', 1)
51+
parent_directory = '/'.join(path_parts[:-1])
3352

34-
if os.path.exists(socket_path):
35-
os.remove(socket_path)
36-
conn_unix = None
37-
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
53+
desired_path = Path(parent_directory)
54+
desired_path.mkdir(parents=True, exist_ok=True)
3855

39-
unix_socket.bind(socket_path)
40-
unix_socket.listen(1)
56+
if os.path.exists(socket_path):
57+
os.remove(socket_path)
58+
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
4159

42-
threading.Thread(target=local_communication, args=(unix_socket, ssl_sock, socket_path)).start()
60+
unix_socket.bind(socket_path)
61+
unix_socket.listen(1)
62+
unix_socket.setblocking(False)
63+
os.chmod(socket_path, 0o600)
64+
except Exception:
65+
raise LocalProxyStartupError(
66+
'Local UNIX socket based proxy was not able to get started.'
67+
)
4368

69+
loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop))
4470

45-
def local_communication(
46-
unix_socket, ssl_sock, socket_path
71+
72+
async def local_communication(
73+
unix_socket, ssl_sock, socket_path, loop
4774
):
48-
try:
49-
conn_unix, addr_unix = unix_socket.accept()
50-
51-
while True:
52-
data = conn_unix.recv(10485760)
53-
if not data:
54-
break
55-
ssl_sock.sendall(data)
56-
response = ssl_sock.recv(10485760)
57-
conn_unix.sendall(response)
58-
59-
finally:
60-
if conn_unix is not None:
61-
conn_unix.close()
62-
unix_socket.close()
63-
os.remove(socket_path) # Clean up the socket file
75+
try:
76+
client, _ = await loop.sock_accept(unix_socket)
77+
78+
while True:
79+
data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE)
80+
if not data:
81+
client.close()
82+
break
83+
ssl_sock.sendall(data)
84+
response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE)
85+
await loop.sock_sendall(client, response)
86+
except Exception:
87+
pass
88+
finally:
89+
client.close()
90+
os.remove(socket_path) # Clean up the socket file

google/cloud/sql/connector/psycopg.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@
2525

2626

2727
def connect(
28-
ip_address: str, sock: ssl.SSLSocket, **kwargs: Any
28+
host: str, sock: ssl.SSLSocket, **kwargs: Any
2929
) -> "psycopg.Connection":
3030
"""Helper function to create a psycopg DB-API connection object.
3131
3232
Args:
33-
ip_address (str): A string containing an IP address for the Cloud SQL
34-
instance.
33+
host (str): A string containing the socket path used by the local proxy.
3534
sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL
3635
server CA cert and ephemeral cert.
3736
kwargs: Additional arguments to pass to the psycopg connect method.
@@ -44,10 +43,7 @@ def connect(
4443
ImportError: The psycopg module cannot be imported.
4544
"""
4645
try:
47-
from psycopg.rows import dict_row
4846
from psycopg import Connection
49-
import threading
50-
from google.cloud.sql.connector.proxy import start_local_proxy
5147
except ImportError:
5248
raise ImportError(
5349
'Unable to import module "psycopg." Please install and try again.'
@@ -59,14 +55,9 @@ def connect(
5955

6056
kwargs.pop("timeout", None)
6157

62-
start_local_proxy(sock, f"/tmp/connector-socket/.s.PGSQL.3307")
63-
6458
conn = Connection.connect(
65-
f"host=/tmp/connector-socket port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require",
66-
autocommit=True,
67-
row_factory=dict_row,
59+
f"host={host} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require",
6860
**kwargs
6961
)
7062

71-
conn.autocommit = True
7263
return conn

tests/system/test_psycopg_connection.py

Lines changed: 82 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,84 @@
1818
import os
1919

2020
# [START cloud_sql_connector_postgres_psycopg]
21+
from typing import Union
22+
23+
import sqlalchemy
2124

2225
from google.cloud.sql.connector import Connector
2326
from google.cloud.sql.connector import DefaultResolver
27+
from google.cloud.sql.connector import DnsResolver
28+
29+
30+
def create_sqlalchemy_engine(
31+
instance_connection_name: str,
32+
user: str,
33+
password: str,
34+
db: str,
35+
ip_type: str = "public",
36+
refresh_strategy: str = "background",
37+
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
38+
) -> tuple[sqlalchemy.engine.Engine, Connector]:
39+
"""Creates a connection pool for a Cloud SQL instance and returns the pool
40+
and the connector. Callers are responsible for closing the pool and the
41+
connector.
42+
43+
A sample invocation looks like:
44+
45+
engine, connector = create_sqlalchemy_engine(
46+
inst_conn_name,
47+
user,
48+
password,
49+
db,
50+
)
51+
with engine.connect() as conn:
52+
time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
53+
conn.commit()
54+
curr_time = time[0]
55+
# do something with query result
56+
connector.close()
57+
58+
Args:
59+
instance_connection_name (str):
60+
The instance connection name specifies the instance relative to the
61+
project and region. For example: "my-project:my-region:my-instance"
62+
user (str):
63+
The database user name, e.g., root
64+
password (str):
65+
The database user's password, e.g., secret-password
66+
db (str):
67+
The name of the database, e.g., mydb
68+
ip_type (str):
69+
The IP type of the Cloud SQL instance to connect to. Can be one
70+
of "public", "private", or "psc".
71+
refresh_strategy (Optional[str]):
72+
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
73+
or "background". For serverless environments use "lazy" to avoid
74+
errors resulting from CPU being throttled.
75+
resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
76+
Resolver class for resolving instance connection name. Use
77+
google.cloud.sql.connector.DnsResolver when resolving DNS domain
78+
names or google.cloud.sql.connector.DefaultResolver for regular
79+
instance connection names ("my-project:my-region:my-instance").
80+
"""
81+
connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver)
82+
83+
# create SQLAlchemy connection pool
84+
engine = sqlalchemy.create_engine(
85+
"postgresql+psycopg://",
86+
creator=lambda: connector.connect(
87+
instance_connection_name,
88+
"psycopg",
89+
user=user,
90+
password=password,
91+
db=db,
92+
local_socket_path="/tmp/conn",
93+
ip_type=ip_type, # can be "public", "private" or "psc"
94+
autocommit=True,
95+
),
96+
)
97+
return engine, connector
2498

25-
from sqlalchemy.dialects.postgresql.base import PGDialect
26-
PGDialect._get_server_version_info = lambda *args: (9, 2)
2799

28100
# [END cloud_sql_connector_postgres_psycopg]
29101

@@ -36,25 +108,12 @@ def test_psycopg_connection() -> None:
36108
db = os.environ["POSTGRES_DB"]
37109
ip_type = os.environ.get("IP_TYPE", "public")
38110

39-
connector = Connector(refresh_strategy="background", resolver=DefaultResolver)
40-
41-
pool = connector.connect(
42-
inst_conn_name,
43-
"psycopg",
44-
user=user,
45-
password=password,
46-
db=db,
47-
ip_type=ip_type, # can be "public", "private" or "psc"
111+
engine, connector = create_sqlalchemy_engine(
112+
inst_conn_name, user, password, db, ip_type
48113
)
49-
50-
with pool as conn:
51-
52-
# Open a cursor to perform database operations
53-
with conn.cursor() as cur:
54-
55-
# Query the database and obtain data as Python objects.
56-
cur.execute("SELECT NOW()")
57-
curr_time = cur.fetchone()["now"]
58-
assert type(curr_time) is datetime
59-
60-
114+
with engine.connect() as conn:
115+
time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
116+
conn.commit()
117+
curr_time = time[0]
118+
assert type(curr_time) is datetime
119+
connector.close()

0 commit comments

Comments
 (0)