Skip to content

Commit d461d75

Browse files
feat: downscope token used for IAM DB AuthN (#488)
1 parent 5f1ab02 commit d461d75

File tree

5 files changed

+82
-14
lines changed

5 files changed

+82
-14
lines changed

google/cloud/sql/connector/instance.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,7 @@ def _auth_init(self, credentials: Optional[Credentials]) -> None:
261261
Credentials object used to authenticate connections to Cloud SQL server.
262262
If not specified, Application Default Credentials are used.
263263
"""
264-
scopes = [
265-
"https://www.googleapis.com/auth/sqlservice.admin",
266-
"https://www.googleapis.com/auth/cloud-platform",
267-
]
264+
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
268265
# if Credentials object is passed in, use for authentication
269266
if isinstance(credentials, Credentials):
270267
credentials = with_scopes_if_required(credentials, scopes=scopes)

google/cloud/sql/connector/refresh_utils.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717

1818
import aiohttp
1919
import google.auth
20-
from google.auth.credentials import Credentials
20+
from google.auth.credentials import Credentials, Scoped
2121
import google.auth.transport.requests
22-
from typing import Any, Dict
22+
from typing import Any, Dict, List
2323
import datetime
24+
import copy
2425
import asyncio
2526
import logging
2627

@@ -166,7 +167,7 @@ async def _get_ephemeral(
166167
elif not isinstance(pub_key, str):
167168
raise TypeError(f"pub_key must be of type str, got {type(pub_key)}")
168169

169-
if not credentials.valid or enable_iam_auth:
170+
if not credentials.valid:
170171
request = google.auth.transport.requests.Request()
171172
credentials.refresh(request)
172173

@@ -181,7 +182,9 @@ async def _get_ephemeral(
181182
data = {"public_key": pub_key}
182183

183184
if enable_iam_auth:
184-
data["access_token"] = credentials.token
185+
# down-scope credentials with only IAM login scope (refreshes them too)
186+
login_creds = _downscope_credentials(credentials)
187+
data["access_token"] = login_creds.token
185188

186189
resp = await client_session.post(
187190
url, headers=headers, json=data, raise_for_status=True
@@ -229,3 +232,38 @@ async def _is_valid(task: asyncio.Task) -> bool:
229232
# supress any errors from task
230233
logger.debug("Current instance metadata is invalid.")
231234
return False
235+
236+
237+
def _downscope_credentials(
238+
credentials: Credentials,
239+
scopes: List[str] = ["https://www.googleapis.com/auth/sqlservice.login"],
240+
) -> Credentials:
241+
"""Generate a down-scoped credential.
242+
243+
:type credentials: google.auth.credentials.Credentials
244+
:param credentials
245+
Credentials object used to generate down-scoped credentials.
246+
247+
:type scopes: List[str]
248+
:param scopes
249+
List of Google scopes to include in down-scoped credentials object.
250+
251+
:rtype: google.auth.credentials.Credentials
252+
:returns: Down-scoped credentials object.
253+
"""
254+
# credentials sourced from a service account or metadata are children of
255+
# Scoped class and are capable of being re-scoped
256+
if isinstance(credentials, Scoped):
257+
scoped_creds = credentials.with_scopes(scopes=scopes)
258+
# authenticated user credentials can not be re-scoped
259+
else:
260+
# create shallow copy to not overwrite scopes on default credentials
261+
scoped_creds = copy.copy(credentials)
262+
# overwrite '_scopes' to down-scope user credentials
263+
# Cloud SDK reference: https://github.com/google-cloud-sdk-unofficial/google-cloud-sdk/blob/93920ccb6d2cce0fe6d1ce841e9e33410551d66b/lib/googlecloudsdk/command_lib/sql/generate_login_token_util.py#L116
264+
scoped_creds._scopes = scopes
265+
# down-scoped credentials require refresh, are invalid after being re-scoped
266+
if not scoped_creds.valid:
267+
request = google.auth.transport.requests.Request()
268+
scoped_creds.refresh(request)
269+
return scoped_creds

tests/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030
from google.cloud.sql.connector.instance import Instance
3131
from google.cloud.sql.connector.utils import generate_keys
3232

33-
SCOPES = [
34-
"https://www.googleapis.com/auth/sqlservice.admin",
35-
"https://www.googleapis.com/auth/cloud-platform",
36-
]
33+
SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"]
3734

3835

3936
def pytest_addoption(parser: Any) -> None:

tests/unit/test_instance.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,9 @@ async def test_perform_refresh_expiration(
294294
expiration = datetime.datetime.utcnow() + datetime.timedelta(minutes=1)
295295
setattr(instance._credentials, "expiry", expiration)
296296
setattr(instance, "_enable_iam_auth", True)
297-
instance_metadata = await instance._perform_refresh()
297+
# set all credentials to valid so downscoped credential does not refresh
298+
with patch.object(Credentials, "valid", True):
299+
instance_metadata = await instance._perform_refresh()
298300

299301
# verify instance metadata object is returned
300302
assert isinstance(instance_metadata, InstanceMetadata)

tests/unit/test_refresh_utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
import aiohttp
1919
from google.auth.credentials import Credentials
20+
import google.oauth2.credentials
2021
import pytest # noqa F401 Needed to run the tests
21-
from mock import Mock
22+
from mock import Mock, patch
2223
from aioresponses import aioresponses
2324
import asyncio
2425

2526
from google.cloud.sql.connector.refresh_utils import (
2627
_get_ephemeral,
2728
_get_metadata,
2829
_is_valid,
30+
_downscope_credentials,
2931
)
3032
from google.cloud.sql.connector.utils import generate_keys
3133

@@ -35,6 +37,7 @@
3537
instance_metadata_expired,
3638
FakeCSQLInstance,
3739
)
40+
from tests.conftest import SCOPES # type: ignore
3841

3942

4043
@pytest.fixture
@@ -231,3 +234,34 @@ async def test_is_valid_with_expired_metadata() -> None:
231234
# task that returns class with expiration 10 mins in past
232235
task = asyncio.create_task(instance_metadata_expired())
233236
assert not await _is_valid(task)
237+
238+
239+
def test_downscope_credentials_service_account(fake_credentials: Credentials) -> None:
240+
"""
241+
Test _downscope_credentials with google.oauth2.service_account.Credentials
242+
which mimics an authenticated service account.
243+
"""
244+
# set all credentials to valid to skip refreshing credentials
245+
with patch.object(Credentials, "valid", True):
246+
credentials = _downscope_credentials(fake_credentials)
247+
# verify default credential scopes have not been altered
248+
assert fake_credentials.scopes == SCOPES
249+
# verify downscoped credentials have new scope
250+
assert credentials.scopes == ["https://www.googleapis.com/auth/sqlservice.login"]
251+
assert credentials != fake_credentials
252+
253+
254+
def test_downscope_credentials_user() -> None:
255+
"""
256+
Test _downscope_credentials with google.oauth2.credentials.Credentials
257+
which mimics an authenticated user.
258+
"""
259+
creds = google.oauth2.credentials.Credentials("token", scopes=SCOPES)
260+
# set all credentials to valid to skip refreshing credentials
261+
with patch.object(Credentials, "valid", True):
262+
credentials = _downscope_credentials(creds)
263+
# verify default credential scopes have not been altered
264+
assert creds.scopes == SCOPES
265+
# verify downscoped credentials have new scope
266+
assert credentials.scopes == ["https://www.googleapis.com/auth/sqlservice.login"]
267+
assert credentials != creds

0 commit comments

Comments
 (0)