Skip to content

Commit 6703232

Browse files
feat: add support for IAM auth with pg8000 driver (#101)
* feat: add support for IAM auth with pg8000 driver * Update google/cloud/sql/connector/instance_connection_manager.py Co-authored-by: Kurtis Van Gent <[email protected]> * Apply suggestions from code review Co-authored-by: Kurtis Van Gent <[email protected]> * calculate seconds to refresh before returning metadata object * address review comments Co-authored-by: Kurtis Van Gent <[email protected]>
1 parent 3d21a8b commit 6703232

File tree

9 files changed

+171
-22
lines changed

9 files changed

+171
-22
lines changed

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ ignore_missing_imports = True
1313
ignore_missing_imports = True
1414

1515
[mypy-pytest]
16+
ignore_missing_imports = True
17+
18+
[mypy-OpenSSL]
1619
ignore_missing_imports = True

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ Note: If specifying Private IP, your application must already be in the same VPC
8282
Tests can be run with `nox`. Change directory into the `cloud-sql-python-connector` and just run `nox` to run the tests.
8383

8484
1. Create a MySQL instance on Google Cloud SQL. Make sure to note your root password when creating the MySQL instance.
85-
2. When the MySQL instance has finished creating, go to the overview page and set the instance’s connection string to the environment variable INSTANCE_CONNECTION_NAME using the following command:
85+
2. When the MySQL instance has finished creating, go to the overview page and set the instance’s connection string to the environment variable MYSQL_CONNECTION_NAME using the following command:
8686
```
87-
export INSTANCE_CONNECTION_NAME=your:connection:string
87+
export MYSQL_CONNECTION_NAME=your:connection:string
8888
```
8989
3. Enable SSL for your Cloud SQL instance by following [these instructions](https://cloud.google.com/sql/docs/mysql/configure-ssl-instance).
9090
4. Create a service account with Cloud SQL Admin and Cloud SQL Client roles, then download the key and save it in a safe location. Set the path to the json file to the environment variable GOOGLE_APPLICATION_CREDENTIALS using the following command:

google/cloud/sql/connector/connector.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def connect(
5252
instance_connection_string: str,
5353
driver: str,
5454
ip_types: IPTypes = IPTypes.PUBLIC,
55+
enable_iam_auth: bool = False,
5556
**kwargs: Any
5657
) -> Any:
5758
"""Prepares and returns a database connection object and starts a
@@ -73,6 +74,10 @@ def connect(
7374
The IP type (public or private) used to connect. IP types
7475
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.
7576
77+
:param enable_iam_auth
78+
Enables IAM based authentication for Postgres instances.
79+
:type enable_iam_auth: bool
80+
7681
:param kwargs:
7782
Pass in any driver-specific arguments needed to connect to the Cloud
7883
SQL instance.
@@ -96,7 +101,9 @@ def connect(
96101
icm = _instances[instance_connection_string]
97102
else:
98103
keys = _get_keys(loop)
99-
icm = InstanceConnectionManager(instance_connection_string, driver, keys, loop)
104+
icm = InstanceConnectionManager(
105+
instance_connection_string, driver, keys, loop, enable_iam_auth
106+
)
100107
_instances[instance_connection_string] = icm
101108

102109
if "timeout" in kwargs:

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
import asyncio
2424
import aiohttp
2525
import concurrent
26+
import datetime
2627
from enum import Enum
2728
import google.auth
2829
from google.auth.credentials import Credentials
2930
import google.auth.transport.requests
31+
import OpenSSL
3032
import ssl
3133
import socket
3234
from tempfile import TemporaryDirectory
@@ -52,9 +54,15 @@
5254
APPLICATION_NAME = "cloud-sql-python-connector"
5355
SERVER_PROXY_PORT = 3307
5456

55-
# The default delay is set to 55 minutes since each ephemeral certificate is only
56-
# valid for an hour. This gives five minutes of buffer time.
57-
_delay: int = 55 * 60
57+
# default_refresh_buffer is the amount of time before a refresh's result expires
58+
# that a new refresh operation begins.
59+
_default_refresh_buffer: int = 5 * 60 # 5 minutes
60+
61+
# _iam_auth_refresh_buffer is the amount of time before a refresh's result expires
62+
# that a new refresh operation begins when IAM DB AuthN is enabled. Because token
63+
# sources may be cached until ~60 seconds before expiration, this value must be smaller
64+
# than default_refresh_buffer.
65+
_iam_auth_refresh_buffer: int = 55 # seconds
5866

5967

6068
class IPTypes(Enum):
@@ -94,16 +102,19 @@ def __init__(self, *args: Any) -> None:
94102
class InstanceMetadata:
95103
ip_addrs: Dict[str, Any]
96104
context: ssl.SSLContext
105+
expiration: datetime.datetime
97106

98107
def __init__(
99108
self,
100109
ephemeral_cert: str,
101110
ip_addrs: Dict[str, Any],
102111
private_key: bytes,
103112
server_ca_cert: str,
113+
expiration: datetime.datetime,
104114
) -> None:
105115
self.ip_addrs = ip_addrs
106116
self.context = ConnectionSSLContext()
117+
self.expiration = expiration
107118

108119
# tmpdir and its contents are automatically deleted after the CA cert
109120
# and ephemeral cert are loaded into the SSLcontext. The values
@@ -140,6 +151,10 @@ class InstanceConnectionManager:
140151
The user agent string to append to SQLAdmin API requests
141152
:type user_agent_string: str
142153
154+
:param enable_iam_auth
155+
Enables IAM based authentication for Postgres instances.
156+
:type enable_iam_auth: bool
157+
143158
:param loop:
144159
A new event loop for the refresh function to run in.
145160
:type loop: asyncio.AbstractEventLoop
@@ -153,6 +168,8 @@ class InstanceConnectionManager:
153168
# https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/22
154169
_loop: asyncio.AbstractEventLoop
155170

171+
_enable_iam_auth: bool
172+
156173
__client_session: Optional[aiohttp.ClientSession] = None
157174

158175
@property
@@ -185,6 +202,7 @@ def __init__(
185202
driver_name: str,
186203
keys: concurrent.futures.Future,
187204
loop: asyncio.AbstractEventLoop,
205+
enable_iam_auth: bool = False,
188206
) -> None:
189207
# Validate connection string
190208
connection_string_split = instance_connection_string.split(":")
@@ -200,6 +218,8 @@ def __init__(
200218
+ "format: project:region:instance."
201219
)
202220

221+
self._enable_iam_auth = enable_iam_auth
222+
203223
self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
204224
self._loop = loop
205225
self._keys = asyncio.wrap_future(keys, loop=self._loop)
@@ -261,16 +281,30 @@ async def _get_instance_data(self) -> InstanceMetadata:
261281
self._project,
262282
self._instance,
263283
pub_key,
284+
self._enable_iam_auth,
264285
)
265286
)
266287

267288
metadata, ephemeral_cert = await asyncio.gather(metadata_task, ephemeral_task)
268289

290+
x509 = OpenSSL.crypto.load_certificate(
291+
OpenSSL.crypto.FILETYPE_PEM, ephemeral_cert
292+
)
293+
expiration = datetime.datetime.strptime(
294+
x509.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
295+
)
296+
if self._credentials is not None:
297+
token_expiration: datetime.datetime = self._credentials.expiry
298+
299+
if expiration > token_expiration:
300+
expiration = token_expiration
301+
269302
return InstanceMetadata(
270303
ephemeral_cert,
271304
metadata["ip_addresses"],
272305
priv_key,
273306
metadata["server_ca_cert"],
307+
expiration,
274308
)
275309

276310
def _auth_init(self) -> None:
@@ -287,6 +321,27 @@ def _auth_init(self) -> None:
287321

288322
self._credentials = credentials
289323

324+
async def seconds_until_refresh(self) -> int:
325+
expiration = (await self._current).expiration
326+
327+
if self._enable_iam_auth:
328+
refresh_buffer = _iam_auth_refresh_buffer
329+
else:
330+
refresh_buffer = _default_refresh_buffer
331+
332+
delay = (expiration - datetime.datetime.now()) - datetime.timedelta(
333+
seconds=refresh_buffer
334+
)
335+
336+
if delay.total_seconds() < 0:
337+
# If the time until the certificate expires is less than the buffer,
338+
# schedule the refresh closer to the expiration time
339+
delay = (expiration - datetime.datetime.now()) - datetime.timedelta(
340+
seconds=5
341+
)
342+
343+
return int(delay.total_seconds())
344+
290345
async def _perform_refresh(self) -> asyncio.Task:
291346
"""Retrieves instance metadata and ephemeral certificate from the
292347
Cloud SQL Instance.
@@ -299,22 +354,23 @@ async def _perform_refresh(self) -> asyncio.Task:
299354

300355
self._current = self._loop.create_task(self._get_instance_data())
301356
# Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes.
302-
self._next = self._loop.create_task(self._schedule_refresh(_delay))
357+
358+
self._next = self._loop.create_task(self._schedule_refresh())
303359

304360
return self._current
305361

306-
async def _schedule_refresh(self, delay: int) -> asyncio.Task:
362+
async def _schedule_refresh(self, delay: Optional[int] = None) -> asyncio.Task:
307363
"""A coroutine that sleeps for the specified amount of time before
308364
running _perform_refresh.
309365
310-
:type delay: int
311-
:param delay: An integer representing the number of seconds for delay.
312-
313366
:rtype: asyncio.Task
314367
:returns: A Task representing _get_instance_data.
315368
"""
316369
logger.debug("Entering sleep")
317370

371+
if delay is None:
372+
delay = await self.seconds_until_refresh()
373+
318374
try:
319375
await asyncio.sleep(delay)
320376
except asyncio.CancelledError as e:
@@ -454,7 +510,7 @@ def _connect_with_pg8000(
454510
)
455511
user = kwargs.pop("user")
456512
db = kwargs.pop("db")
457-
passwd = kwargs.pop("password")
513+
passwd = kwargs.pop("password", None)
458514
setattr(ctx, "request_ssl", False)
459515
return pg8000.dbapi.connect(
460516
user,

google/cloud/sql/connector/refresh_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ async def _get_ephemeral(
102102
project: str,
103103
instance: str,
104104
pub_key: str,
105+
enable_iam_auth: bool = False,
105106
) -> str:
106107
"""Asynchronously requests an ephemeral certificate from the Cloud SQL Instance.
107108
@@ -120,6 +121,10 @@ async def _get_ephemeral(
120121
:type pub_key:
121122
:param str: A string representing PEM-encoded RSA public key.
122123
124+
:type enable_iam_auth: bool
125+
:param enable_iam_auth
126+
Enables IAM based authentication for Postgres instances.
127+
123128
:rtype: str
124129
:returns: An ephemeral certificate from the Cloud SQL instance that allows
125130
authorized connections to the instance.
@@ -141,7 +146,7 @@ async def _get_ephemeral(
141146
elif not isinstance(pub_key, str):
142147
raise TypeError(f"pub_key must be of type str, got {type(pub_key)}")
143148

144-
if not credentials.valid:
149+
if not credentials.valid or enable_iam_auth:
145150
request = google.auth.transport.requests.Request()
146151
credentials.refresh(request)
147152

@@ -155,6 +160,9 @@ async def _get_ephemeral(
155160

156161
data = {"public_key": pub_key}
157162

163+
if enable_iam_auth:
164+
data["access_token"] = credentials.token
165+
158166
resp = await client_session.post(
159167
url, headers=headers, json=data, raise_for_status=True
160168
)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def connect_string() -> str:
6565
returns it.
6666
"""
6767
try:
68-
connect_string = os.environ["INSTANCE_CONNECTION_NAME"]
68+
connect_string = os.environ["MYSQL_CONNECTION_NAME"]
6969
except KeyError:
7070
raise KeyError(
7171
"Please set environment variable 'INSTANCE_CONNECTION"
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
""""
2+
Copyright 2021 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import os
17+
import uuid
18+
from typing import Generator
19+
20+
import pg8000
21+
import pytest
22+
import sqlalchemy
23+
from google.cloud.sql.connector import connector
24+
25+
table_name = f"books_{uuid.uuid4().hex}"
26+
27+
28+
def init_connection_engine() -> sqlalchemy.engine.Engine:
29+
def getconn() -> pg8000.dbapi.Connection:
30+
conn: pg8000.dbapi.Connection = connector.connect(
31+
os.environ["POSTGRES_IAM_CONNECTION_NAME"],
32+
"pg8000",
33+
user=os.environ["POSTGRES_IAM_USER"],
34+
db=os.environ["POSTGRES_DB"],
35+
enable_iam_auth=True,
36+
)
37+
return conn
38+
39+
engine = sqlalchemy.create_engine(
40+
"postgresql+pg8000://",
41+
creator=getconn,
42+
)
43+
engine.dialect.description_encoding = None
44+
return engine
45+
46+
47+
@pytest.fixture(name="pool")
48+
def setup() -> Generator:
49+
pool = init_connection_engine()
50+
51+
with pool.connect() as conn:
52+
conn.execute(
53+
f"CREATE TABLE IF NOT EXISTS {table_name}"
54+
" ( id CHAR(20) NOT NULL, title TEXT NOT NULL );"
55+
)
56+
57+
yield pool
58+
59+
with pool.connect() as conn:
60+
conn.execute(f"DROP TABLE IF EXISTS {table_name}")
61+
62+
63+
def test_pooled_connection_with_pg8000_iam_auth(pool: sqlalchemy.engine.Engine) -> None:
64+
insert_stmt = sqlalchemy.text(
65+
f"INSERT INTO {table_name} (id, title) VALUES (:id, :title)",
66+
)
67+
with pool.connect() as conn:
68+
conn.execute(insert_stmt, id="book1", title="Book One")
69+
conn.execute(insert_stmt, id="book2", title="Book Two")
70+
71+
select_stmt = sqlalchemy.text(f"SELECT title FROM {table_name} ORDER BY ID;")
72+
with pool.connect() as conn:
73+
rows = conn.execute(select_stmt).fetchall()
74+
titles = [row[0] for row in rows]
75+
76+
assert titles == ["Book One", "Book Two"]

tests/unit/test_instance_connection_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -
5454

5555
@pytest.mark.asyncio
5656
async def test_InstanceConnectionManager_perform_refresh(
57-
icm: InstanceConnectionManager,
57+
icm: InstanceConnectionManager, async_loop: asyncio.AbstractEventLoop
5858
) -> None:
5959
"""
60-
Test to check whether _get_perform works as described given valid
60+
Test to check whether _perform_refresh works as described given valid
6161
conditions.
6262
"""
6363
task = await icm._perform_refresh()

0 commit comments

Comments
 (0)