Skip to content

Commit 4d6a958

Browse files
feat: add support for specifying separate credentials for authenticating with DB (#510)
This PR does the following: - When authenticating with the DB, the alloydb.login scope is used instead of the cloud-platform scope. Authentication with the AlloyDB Admin API continues to use the cloud-platform scope. - The user can specify separate credentials, with the db_credentials argument passed into the Connector() and AsyncConnector(). When specified, only the db_credentials will be used to authenticate with the DB. If not specified, the existing behavior is preserved with credentials being used to authenticate with the DB.
1 parent 8f7bae7 commit 4d6a958

File tree

5 files changed

+232
-12
lines changed

5 files changed

+232
-12
lines changed

google/cloud/alloydbconnector/async_connector.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ class AsyncConnector:
4848
credentials (google.auth.credentials.Credentials):
4949
A credentials object created from the google-auth Python library.
5050
If not specified, Application Default Credentials are used.
51+
These are the credentials used for authenticating with the AlloyDB
52+
Admin API.
53+
db_credentials (google.auth.credentials.Credentials):
54+
A credentials object created from the google-auth Python library.
55+
This is only used when Auto IAM AuthN is enabled.
56+
If not specified, the credentials used for authenticating with the
57+
AlloyDB Admin API will also be used to authenticate with the DB.
58+
If specified, the credential's scope should be
59+
"https://www.googleapis.com/auth/alloydb.login".
5160
quota_project (str): The Project ID for an existing Google Cloud
5261
project. The project specified is used for quota and
5362
billing purposes.
@@ -67,6 +76,7 @@ class AsyncConnector:
6776
def __init__(
6877
self,
6978
credentials: Optional[Credentials] = None,
79+
db_credentials: Optional[Credentials] = None,
7080
quota_project: Optional[str] = None,
7181
alloydb_api_endpoint: str = "alloydb.googleapis.com",
7282
enable_iam_auth: bool = False,
@@ -88,13 +98,23 @@ def __init__(
8898
refresh_strategy = RefreshStrategy(refresh_strategy.upper())
8999
self._refresh_strategy = refresh_strategy
90100
self._user_agent = user_agent
91-
# initialize credentials
101+
# initialize credentials for authenticating with AlloyDB Admin API
92102
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
93103
if credentials:
94104
self._credentials = with_scopes_if_required(credentials, scopes=scopes)
95105
# otherwise use application default credentials
96106
else:
97107
self._credentials, _ = google.auth.default(scopes=scopes)
108+
# initialize credentials for authenticating with the DB
109+
if db_credentials:
110+
self._db_credentials = db_credentials
111+
# otherwise use the same credentials as the one for authenticating with
112+
# AlloyDB Admin API
113+
else:
114+
scopes = ["https://www.googleapis.com/auth/alloydb.login"]
115+
self._db_credentials = with_scopes_if_required(
116+
self._credentials, scopes=scopes
117+
)
98118

99119
# check if AsyncConnector is being initialized with event loop running
100120
# Otherwise we will lazy init keys
@@ -196,13 +216,13 @@ async def connect(
196216
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")
197217

198218
# callable to be used for auto IAM authn
199-
def get_authentication_token() -> str:
219+
async def get_authentication_token() -> str:
200220
"""Get OAuth2 access token to be used for IAM database authentication"""
201221
# refresh credentials if expired
202-
if not self._credentials.valid:
222+
if not self._db_credentials.valid:
203223
request = google.auth.transport.requests.Request()
204-
self._credentials.refresh(request)
205-
return self._credentials.token
224+
await asyncio.to_thread(self._db_credentials.refresh, request)
225+
return self._db_credentials.token
206226

207227
# if enable_iam_auth is set, use auth token as database password
208228
if enable_iam_auth:

google/cloud/alloydbconnector/connector.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ class Connector:
6262
credentials (google.auth.credentials.Credentials):
6363
A credentials object created from the google-auth Python library.
6464
If not specified, Application Default Credentials are used.
65+
These are the credentials used for authenticating with the AlloyDB
66+
Admin API.
67+
db_credentials (google.auth.credentials.Credentials):
68+
A credentials object created from the google-auth Python library.
69+
If not specified, the credentials used for authenticating with the
70+
AlloyDB Admin API will also be used to authenticate with the DB.
71+
If specified, the credential's scope should be
72+
"https://www.googleapis.com/auth/alloydb.login".
6573
quota_project (str): The Project ID for an existing Google Cloud
6674
project. The project specified is used for quota and
6775
billing purposes.
@@ -87,6 +95,7 @@ class Connector:
8795
def __init__(
8896
self,
8997
credentials: Optional[Credentials] = None,
98+
db_credentials: Optional[Credentials] = None,
9099
quota_project: Optional[str] = None,
91100
alloydb_api_endpoint: str = "alloydb.googleapis.com",
92101
enable_iam_auth: bool = False,
@@ -113,13 +122,23 @@ def __init__(
113122
refresh_strategy = RefreshStrategy(refresh_strategy.upper())
114123
self._refresh_strategy = refresh_strategy
115124
self._user_agent = user_agent
116-
# initialize credentials
125+
# initialize credentials for authenticating with AlloyDB Admin API
117126
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
118127
if credentials:
119128
self._credentials = with_scopes_if_required(credentials, scopes=scopes)
120129
# otherwise use application default credentials
121130
else:
122131
self._credentials, _ = default(scopes=scopes)
132+
# initialize credentials for authenticating with the DB
133+
if db_credentials:
134+
self._db_credentials = db_credentials
135+
# otherwise use the same credentials as the one for authenticating with
136+
# AlloyDB Admin API
137+
else:
138+
scopes = ["https://www.googleapis.com/auth/alloydb.login"]
139+
self._db_credentials = with_scopes_if_required(
140+
self._credentials, scopes=scopes
141+
)
123142
self._keys = asyncio.wrap_future(
124143
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
125144
loop=self._loop,
@@ -296,14 +315,14 @@ def metadata_exchange(
296315
auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM
297316

298317
# Ensure the credentials are in fact valid before proceeding.
299-
if not self._credentials.token_state == TokenState.FRESH:
300-
self._credentials.refresh(requests.Request())
318+
if not self._db_credentials.token_state == TokenState.FRESH:
319+
self._db_credentials.refresh(requests.Request())
301320

302321
# form metadata exchange request
303322
req = connectorspb.MetadataExchangeRequest(
304323
user_agent=f"{self._client._user_agent}", # type: ignore
305324
auth_type=auth_type,
306-
oauth2_token=self._credentials.token,
325+
oauth2_token=self._db_credentials.token,
307326
)
308327

309328
# set I/O timeout

tests/unit/mocks.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@
2121
import json
2222
import ssl
2323
import struct
24-
from typing import Any, Callable, Literal, Optional
24+
from typing import Any, Callable, Literal, Optional, Sequence
2525

2626
from cryptography import x509
2727
from cryptography.hazmat.primitives import hashes
2828
from cryptography.hazmat.primitives import serialization
2929
from cryptography.hazmat.primitives.asymmetric import rsa
3030
from cryptography.x509.oid import NameOID
3131
from google.auth.credentials import _helpers
32+
from google.auth.credentials import Scoped
3233
from google.auth.credentials import TokenState
3334
from google.auth.transport import requests
3435

@@ -57,6 +58,11 @@ def expired(self) -> bool:
5758
"""
5859
return False if not self.expiry else True
5960

61+
@property
62+
def valid(self) -> bool:
63+
"""Checks if the credentials are valid."""
64+
return self.token is not None and not self.expired
65+
6066
@property
6167
def token_state(
6268
self,
@@ -87,6 +93,26 @@ def token_state(
8793
return TokenState.FRESH
8894

8995

96+
class FakeCredentialsRequiresScopes(Scoped):
97+
def requires_scopes(self) -> bool:
98+
"""
99+
Overrides the requires_scopes() method of the Scoped class to require
100+
scopes for these credentials.
101+
"""
102+
return True
103+
104+
def with_scopes(
105+
self, scopes: Sequence[str], default_scopes: Optional[Sequence[str]] = None
106+
) -> "FakeCredentialsRequiresScopes":
107+
"""
108+
Overrides the with_scopes() method of the Scoped class to create a
109+
copy of these credentials with the specified scopes.
110+
"""
111+
f = FakeCredentialsRequiresScopes()
112+
f._scopes = scopes
113+
return f
114+
115+
90116
def generate_cert(
91117
common_name: str, expires_in: int = 60, server_cert: bool = False
92118
) -> tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]:

tests/unit/test_async_connector.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
from typing import Union
16+
from typing import Any, Union
1717

1818
from google.api_core.exceptions import RetryError
1919
from google.api_core.retry.retry_unary_async import AsyncRetry
2020
from mock import patch
2121
from mocks import FakeAlloyDBClient
2222
from mocks import FakeConnectionInfo
2323
from mocks import FakeCredentials
24+
from mocks import FakeCredentialsRequiresScopes
2425
import pytest
2526

2627
from google.cloud.alloydbconnector import AsyncConnector
@@ -44,11 +45,42 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None:
4445
assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT
4546
assert connector._client is None
4647
assert connector._credentials == credentials
48+
assert connector._db_credentials == credentials
4749
assert connector._enable_iam_auth is False
4850
assert connector._closed is False
4951
await connector.close()
5052

5153

54+
@pytest.mark.asyncio
55+
async def test_AsyncConnector_init_db_credentials(credentials: FakeCredentials) -> None:
56+
"""
57+
Test to check whether the __init__ method of AsyncConnector
58+
properly sets db_credentials when specified.
59+
"""
60+
db_credentials = FakeCredentials()
61+
connector = AsyncConnector(credentials, db_credentials)
62+
assert connector._db_credentials == db_credentials
63+
await connector.close()
64+
65+
66+
async def test_AsyncConnector_init_scopes() -> None:
67+
"""
68+
Test to check whether the __init__ method of AsyncConnector
69+
properly sets the credential's scopes.
70+
"""
71+
credentials = FakeCredentialsRequiresScopes()
72+
connector = AsyncConnector(credentials)
73+
assert connector._credentials != credentials
74+
assert connector._credentials._scopes == [
75+
"https://www.googleapis.com/auth/cloud-platform"
76+
]
77+
assert connector._db_credentials != credentials
78+
assert connector._db_credentials._scopes == [
79+
"https://www.googleapis.com/auth/alloydb.login"
80+
]
81+
await connector.close()
82+
83+
5284
@pytest.mark.parametrize(
5385
"ip_type, expected",
5486
[
@@ -154,6 +186,7 @@ async def test_AsyncConnector_context_manager(
154186
assert connector._alloydb_api_endpoint == ALLOYDB_API_ENDPOINT
155187
assert connector._client is None
156188
assert connector._credentials == credentials
189+
assert connector._db_credentials == credentials
157190
assert connector._enable_iam_auth is False
158191

159192

@@ -197,6 +230,66 @@ async def test_connect_and_close(credentials: FakeCredentials) -> None:
197230
assert connection.result() is True
198231

199232

233+
@pytest.mark.asyncio
234+
async def test_connect_iam_authn(credentials: FakeCredentials) -> None:
235+
"""
236+
Test that connector.connect, with IAM authentication, refreshes credentials.
237+
"""
238+
async with AsyncConnector(credentials, enable_iam_auth=True) as connector:
239+
connector._client = FakeAlloyDBClient()
240+
241+
async def custom_connect(*_: Any, **kwargs: Any) -> bool:
242+
passwd = kwargs.pop("password")
243+
await passwd()
244+
245+
# patch db connection creation
246+
with patch(
247+
"google.cloud.alloydbconnector.asyncpg.connect", side_effect=custom_connect
248+
):
249+
await connector.connect(
250+
TEST_INSTANCE_NAME,
251+
"asyncpg",
252+
user="test-user",
253+
password="test-password",
254+
db="test-db",
255+
)
256+
# check DB authentication refreshed the credentials
257+
assert connector._credentials.token
258+
assert connector._db_credentials.token
259+
260+
261+
@pytest.mark.asyncio
262+
async def test_connect_db_credentials_iam_authn(credentials: FakeCredentials) -> None:
263+
"""
264+
Test that connector.connect, with IAM authentication, refreshes only the DB
265+
credentials when specified.
266+
"""
267+
db_credentials = FakeCredentials()
268+
async with AsyncConnector(
269+
credentials, db_credentials, enable_iam_auth=True
270+
) as connector:
271+
connector._client = FakeAlloyDBClient()
272+
273+
async def custom_connect(*_: Any, **kwargs: Any) -> bool:
274+
passwd = kwargs.pop("password")
275+
await passwd()
276+
277+
# patch db connection creation
278+
with patch(
279+
"google.cloud.alloydbconnector.asyncpg.connect", side_effect=custom_connect
280+
):
281+
await connector.connect(
282+
TEST_INSTANCE_NAME,
283+
"asyncpg",
284+
user="test-user",
285+
password="test-password",
286+
db="test-db",
287+
)
288+
# check DB authentication refreshed only the DB credential's token
289+
assert not connector._credentials.token
290+
assert connector._db_credentials.token
291+
292+
200293
@pytest.mark.asyncio
201294
async def test_force_refresh(credentials: FakeCredentials) -> None:
202295
"""

0 commit comments

Comments
 (0)