Skip to content

Commit 2c802ae

Browse files
author
Uziel Silva
committed
Changelog:
- Add proxy for connections that can only be made through an unix socket, to support the TLS connection - Add support for psycopg, using the proxy server - Add unit and integration tests - Update docs
1 parent 113c684 commit 2c802ae

File tree

9 files changed

+239
-4
lines changed

9 files changed

+239
-4
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ The Cloud SQL Python Connector is a package to be used alongside a database driv
4242
Currently supported drivers are:
4343
- [`pymysql`](https://github.com/PyMySQL/PyMySQL) (MySQL)
4444
- [`pg8000`](https://github.com/tlocke/pg8000) (PostgreSQL)
45+
- [`psycopg`](https://github.com/psycopg/psycopg) (PostgreSQL)
4546
- [`asyncpg`](https://github.com/MagicStack/asyncpg) (PostgreSQL)
4647
- [`pytds`](https://github.com/denisenkom/pytds) (SQL Server)
4748

@@ -587,7 +588,7 @@ async def main():
587588
# acquire connection and query Cloud SQL database
588589
async with pool.acquire() as conn:
589590
res = await conn.fetch("SELECT NOW()")
590-
591+
591592
# close Connector
592593
await connector.close_async()
593594
```

google/cloud/sql/connector/connector.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from google.cloud.sql.connector.lazy import LazyRefreshCache
3838
from google.cloud.sql.connector.monitored_cache import MonitoredCache
3939
import google.cloud.sql.connector.pg8000 as pg8000
40+
import google.cloud.sql.connector.psycopg as psycopg
4041
import google.cloud.sql.connector.pymysql as pymysql
4142
import google.cloud.sql.connector.pytds as pytds
4243
from google.cloud.sql.connector.resolver import DefaultResolver
@@ -230,7 +231,7 @@ def connect(
230231
Example: "my-project:us-central1:my-instance"
231232
232233
driver (str): A string representing the database driver to connect
233-
with. Supported drivers are pymysql, pg8000, and pytds.
234+
with. Supported drivers are pymysql, pg8000, psycopg, and pytds.
234235
235236
**kwargs: Any driver-specific arguments to pass to the underlying
236237
driver .connect call.
@@ -266,7 +267,8 @@ async def connect_async(
266267
Example: "my-project:us-central1:my-instance"
267268
268269
driver (str): A string representing the database driver to connect
269-
with. Supported drivers are pymysql, asyncpg, pg8000, and pytds.
270+
with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and
271+
pytds.
270272
271273
**kwargs: Any driver-specific arguments to pass to the underlying
272274
driver .connect call.
@@ -278,7 +280,7 @@ async def connect_async(
278280
ValueError: Connection attempt with built-in database authentication
279281
and then subsequent attempt with IAM database authentication.
280282
KeyError: Unsupported database driver Must be one of pymysql, asyncpg,
281-
pg8000, and pytds.
283+
pg8000, psycopg, and pytds.
282284
"""
283285
if self._keys is None:
284286
self._keys = asyncio.create_task(generate_keys())
@@ -332,6 +334,7 @@ async def connect_async(
332334
connect_func = {
333335
"pymysql": pymysql.connect,
334336
"pg8000": pg8000.connect,
337+
"psycopg": psycopg.connect,
335338
"asyncpg": asyncpg.connect,
336339
"pytds": pytds.connect,
337340
}

google/cloud/sql/connector/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class DriverMapping(Enum):
6262

6363
ASYNCPG = "POSTGRES"
6464
PG8000 = "POSTGRES"
65+
PSYCOPG = "POSTGRES"
6566
PYMYSQL = "MYSQL"
6667
PYTDS = "SQLSERVER"
6768

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import socket
18+
import os
19+
import threading
20+
21+
SERVER_PROXY_PORT = 3307
22+
23+
def start_local_proxy(
24+
ssl_sock,
25+
socket_path,
26+
):
27+
if os.path.exists(socket_path):
28+
os.remove(socket_path)
29+
conn_unix = None
30+
unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
31+
32+
unix_socket.bind(socket_path)
33+
unix_socket.listen(1)
34+
35+
threading.Thread(target=local_communication, args=(unix_socket, ssl_sock, socket_path)).start()
36+
37+
38+
def local_communication(
39+
unix_socket, ssl_sock, socket_path
40+
):
41+
try:
42+
conn_unix, addr_unix = unix_socket.accept()
43+
44+
while True:
45+
data = conn_unix.recv(10485760)
46+
if not data:
47+
break
48+
ssl_sock.sendall(data)
49+
response = ssl_sock.recv(10485760)
50+
conn_unix.sendall(response)
51+
52+
finally:
53+
if conn_unix is not None:
54+
conn_unix.close()
55+
unix_socket.close()
56+
os.remove(socket_path) # Clean up the socket file
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import ssl
18+
from typing import Any, TYPE_CHECKING
19+
import threading
20+
21+
SERVER_PROXY_PORT = 3307
22+
23+
if TYPE_CHECKING:
24+
import psycopg
25+
26+
27+
def connect(
28+
ip_address: str, sock: ssl.SSLSocket, **kwargs: Any
29+
) -> "psycopg.Connection":
30+
"""Helper function to create a psycopg DB-API connection object.
31+
32+
Args:
33+
ip_address (str): A string containing an IP address for the Cloud SQL
34+
instance.
35+
sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL
36+
server CA cert and ephemeral cert.
37+
kwargs: Additional arguments to pass to the psycopg connect method.
38+
39+
Returns:
40+
psycopg.Connection: A psycopg connection to the Cloud SQL
41+
instance.
42+
43+
Raises:
44+
ImportError: The psycopg module cannot be imported.
45+
"""
46+
try:
47+
from psycopg.rows import dict_row
48+
from psycopg import Connection
49+
import threading
50+
from google.cloud.sql.connector.proxy import start_local_proxy
51+
except ImportError:
52+
raise ImportError(
53+
'Unable to import module "psycopg." Please install and try again.'
54+
)
55+
56+
user = kwargs.pop("user")
57+
db = kwargs.pop("db")
58+
passwd = kwargs.pop("password", None)
59+
60+
kwargs.pop("timeout", None)
61+
62+
start_local_proxy(sock, f"/tmp/connector-socket/.s.PGSQL.3307")
63+
64+
conn = Connection.connect(
65+
f"host=/tmp/connector-socket port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require",
66+
autocommit=True,
67+
row_factory=dict_row,
68+
**kwargs
69+
)
70+
71+
conn.autocommit = True
72+
return conn

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Changelog = "https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/b
5959
[project.optional-dependencies]
6060
pymysql = ["PyMySQL>=1.1.0"]
6161
pg8000 = ["pg8000>=1.31.1"]
62+
psycopg = ["psycopg>=3.2.9"]
6263
pytds = ["python-tds>=1.15.0"]
6364
asyncpg = ["asyncpg>=0.30.0"]
6465

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ sqlalchemy-pytds==1.0.2
77
sqlalchemy-stubs==0.4
88
PyMySQL==1.1.1
99
pg8000==1.31.2
10+
psycopg[binary]==3.2.9
1011
asyncpg==0.30.0
1112
python-tds==1.16.1
1213
aioresponses==0.7.8
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from datetime import datetime
18+
import os
19+
20+
# [START cloud_sql_connector_postgres_psycopg]
21+
22+
from google.cloud.sql.connector import Connector
23+
from google.cloud.sql.connector import DefaultResolver
24+
25+
from sqlalchemy.dialects.postgresql.base import PGDialect
26+
PGDialect._get_server_version_info = lambda *args: (9, 2)
27+
28+
# [END cloud_sql_connector_postgres_psycopg]
29+
30+
31+
def test_psycopg_connection() -> None:
32+
"""Basic test to get time from database."""
33+
inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"]
34+
user = os.environ["POSTGRES_USER"]
35+
password = os.environ["POSTGRES_PASS"]
36+
db = os.environ["POSTGRES_DB"]
37+
ip_type = os.environ.get("IP_TYPE", "public")
38+
39+
connector = Connector(refresh_strategy="background", resolver=DefaultResolver)
40+
41+
pool = connector.connect(
42+
inst_conn_name,
43+
"psycopg",
44+
user=user,
45+
password=password,
46+
db=db,
47+
ip_type=ip_type, # can be "public", "private" or "psc"
48+
)
49+
50+
with pool as conn:
51+
52+
# Open a cursor to perform database operations
53+
with conn.cursor() as cur:
54+
55+
# Query the database and obtain data as Python objects.
56+
cur.execute("SELECT NOW()")
57+
curr_time = cur.fetchone()["now"]
58+
assert type(curr_time) is datetime
59+
60+

tests/unit/test_psycopg.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import socket
18+
import ssl
19+
from typing import Any
20+
21+
from mock import patch, PropertyMock
22+
import pytest
23+
24+
from google.cloud.sql.connector.psycopg import connect
25+
26+
27+
@pytest.mark.usefixtures("proxy_server")
28+
async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None:
29+
"""Test to verify that psycopg gets to proper connection call."""
30+
ip_addr = "127.0.0.1"
31+
sock = context.wrap_socket(
32+
socket.create_connection((ip_addr, 3307)),
33+
server_hostname=ip_addr,
34+
)
35+
with patch("psycopg.connect") as mock_connect:
36+
type(mock_connect.return_value).autocommit = PropertyMock(return_value=True)
37+
connection = connect(ip_addr, sock, **kwargs)
38+
assert connection.autocommit is True
39+
# verify that driver connection call would be made
40+
assert mock_connect.assert_called_once

0 commit comments

Comments
 (0)