Skip to content

Commit 85f5476

Browse files
feat: add arg for specifying credentials (#226)
1 parent 8359f85 commit 85f5476

File tree

3 files changed

+105
-13
lines changed

3 files changed

+105
-13
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
IPTypes,
2222
)
2323
from google.cloud.sql.connector.utils import generate_keys
24-
24+
from google.auth.credentials import Credentials
2525
from threading import Thread
26-
from typing import Any, Dict
26+
from typing import Any, Dict, Optional
2727

2828
logger = logging.getLogger(name=__name__)
2929

@@ -43,16 +43,21 @@ class Connector:
4343
Enables IAM based authentication (Postgres only).
4444
4545
:type timeout: int
46-
:param timeout:
46+
:param timeout
4747
The time limit for a connection before raising a TimeoutError.
4848
49+
:type credentials: google.auth.credentials.Credentials
50+
:param credentials
51+
Credentials object used to authenticate connections to Cloud SQL server.
52+
If not specified, Application Default Credentials are used.
4953
"""
5054

5155
def __init__(
5256
self,
5357
ip_types: IPTypes = IPTypes.PUBLIC,
5458
enable_iam_auth: bool = False,
5559
timeout: int = 30,
60+
credentials: Optional[Credentials] = None,
5661
) -> None:
5762
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
5863
self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True)
@@ -66,6 +71,7 @@ def __init__(
6671
self._timeout = timeout
6772
self._enable_iam_auth = enable_iam_auth
6873
self._ip_types = ip_types
74+
self._credentials = credentials
6975

7076
def connect(
7177
self, instance_connection_string: str, driver: str, **kwargs: Any
@@ -112,6 +118,7 @@ def connect(
112118
driver,
113119
self._keys,
114120
self._loop,
121+
self._credentials,
115122
enable_iam_auth,
116123
)
117124
self._instances[instance_connection_string] = icm

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import datetime
2828
from enum import Enum
2929
import google.auth
30-
from google.auth.credentials import Credentials
30+
from google.auth.credentials import Credentials, with_scopes_if_required
3131
import google.auth.transport.requests
3232
import OpenSSL
3333
import platform
@@ -117,6 +117,15 @@ def __init__(self, *args: Any) -> None:
117117
super(PlatformNotSupportedError, self).__init__(self, *args)
118118

119119

120+
class CredentialsTypeError(Exception):
121+
"""
122+
Raised when credentials parameter is not proper type.
123+
"""
124+
125+
def __init__(self, *args: Any) -> None:
126+
super(CredentialsTypeError, self).__init__(self, *args)
127+
128+
120129
class InstanceMetadata:
121130
ip_addrs: Dict[str, Any]
122131
context: ssl.SSLContext
@@ -177,6 +186,11 @@ class InstanceConnectionManager:
177186
The user agent string to append to SQLAdmin API requests
178187
:type user_agent_string: str
179188
189+
:type credentials: google.auth.credentials.Credentials
190+
:param credentials
191+
Credentials object used to authenticate connections to Cloud SQL server.
192+
If not specified, Application Default Credentials are used.
193+
180194
:param enable_iam_auth
181195
Enables IAM based authentication for Postgres instances.
182196
:type enable_iam_auth: bool
@@ -229,6 +243,7 @@ def __init__(
229243
driver_name: str,
230244
keys: concurrent.futures.Future,
231245
loop: asyncio.AbstractEventLoop,
246+
credentials: Optional[Credentials] = None,
232247
enable_iam_auth: bool = False,
233248
) -> None:
234249
# Validate connection string
@@ -250,7 +265,14 @@ def __init__(
250265
self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}"
251266
self._loop = loop
252267
self._keys = asyncio.wrap_future(keys, loop=self._loop)
253-
self._auth_init()
268+
# validate credentials type
269+
if not isinstance(credentials, Credentials) and credentials is not None:
270+
raise CredentialsTypeError(
271+
"Arg credentials must be type 'google.auth.credentials.Credentials' "
272+
"or None (to use Application Default Credentials)"
273+
)
274+
275+
self._auth_init(credentials)
254276

255277
self._refresh_rate_limiter = AsyncRateLimiter(
256278
max_capacity=2, rate=1 / 30, loop=self._loop
@@ -343,17 +365,25 @@ async def _get_instance_data(self) -> InstanceMetadata:
343365
self._enable_iam_auth,
344366
)
345367

346-
def _auth_init(self) -> None:
368+
def _auth_init(self, credentials: Optional[Credentials]) -> None:
347369
"""Creates and assigns a Google Python API service object for
348370
Google Cloud SQL Admin API.
349-
"""
350371
351-
credentials, project = google.auth.default(
352-
scopes=[
353-
"https://www.googleapis.com/auth/sqlservice.admin",
354-
"https://www.googleapis.com/auth/cloud-platform",
355-
]
356-
)
372+
:type credentials: google.auth.credentials.Credentials
373+
:param credentials
374+
Credentials object used to authenticate connections to Cloud SQL server.
375+
If not specified, Application Default Credentials are used.
376+
"""
377+
scopes = [
378+
"https://www.googleapis.com/auth/sqlservice.admin",
379+
"https://www.googleapis.com/auth/cloud-platform",
380+
]
381+
# if Credentials object is passed in, use for authentication
382+
if isinstance(credentials, Credentials):
383+
credentials = with_scopes_if_required(credentials, scopes=scopes)
384+
# otherwise use application default credentials
385+
else:
386+
credentials, project = google.auth.default(scopes=scopes)
357387

358388
self._credentials = credentials
359389

tests/unit/test_instance_connection_manager.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,24 @@
1515
"""
1616

1717
import asyncio
18+
from unittest.mock import Mock, patch
1819
import datetime
1920
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
2021
from typing import Any
2122
import pytest # noqa F401 Needed to run the tests
23+
from google.auth.credentials import Credentials
2224
from google.cloud.sql.connector.instance_connection_manager import (
2325
InstanceConnectionManager,
26+
CredentialsTypeError,
2427
)
2528
from google.cloud.sql.connector.utils import generate_keys
2629

2730

31+
@pytest.fixture
32+
def mock_credentials() -> Credentials:
33+
return Mock(spec=Credentials)
34+
35+
2836
@pytest.fixture
2937
def icm(
3038
async_loop: asyncio.AbstractEventLoop, connect_string: str
@@ -73,6 +81,21 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -
7381
)
7482

7583

84+
def test_InstanceConnectionManager_init_bad_credentials(
85+
async_loop: asyncio.AbstractEventLoop,
86+
) -> None:
87+
"""
88+
Test to check whether the __init__ method of InstanceConnectionManager
89+
throws proper error for bad credentials arg type.
90+
"""
91+
connect_string = "test-project:test-region:test-instance"
92+
keys = asyncio.run_coroutine_threadsafe(generate_keys(), async_loop)
93+
with pytest.raises(CredentialsTypeError):
94+
assert InstanceConnectionManager(
95+
connect_string, "pymysql", keys, async_loop, credentials=1
96+
)
97+
98+
7699
@pytest.mark.asyncio
77100
async def test_perform_refresh_replaces_result(
78101
icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter
@@ -171,3 +194,35 @@ async def test_force_refresh_cancels_pending_refresh(
171194

172195
assert pending_refresh.cancelled() is True
173196
assert isinstance(icm._current.result(), MockMetadata)
197+
198+
199+
def test_auth_init_with_credentials_object(
200+
icm: InstanceConnectionManager, mock_credentials: Credentials
201+
) -> None:
202+
"""
203+
Test that InstanceConnectionManager's _auth_init initializes _credentials
204+
when passed a google.auth.credentials.Credentials object.
205+
"""
206+
setattr(icm, "_credentials", None)
207+
with patch(
208+
"google.cloud.sql.connector.instance_connection_manager.with_scopes_if_required"
209+
) as mock_auth:
210+
mock_auth.return_value = mock_credentials
211+
icm._auth_init(credentials=mock_credentials)
212+
assert isinstance(icm._credentials, Credentials)
213+
mock_auth.assert_called_once()
214+
215+
216+
def test_auth_init_with_default_credentials(
217+
icm: InstanceConnectionManager, mock_credentials: Credentials
218+
) -> None:
219+
"""
220+
Test that InstanceConnectionManager's _auth_init initializes _credentials
221+
with application default credentials when credentials are not specified.
222+
"""
223+
setattr(icm, "_credentials", None)
224+
with patch("google.auth.default") as mock_auth:
225+
mock_auth.return_value = mock_credentials, None
226+
icm._auth_init(credentials=None)
227+
assert isinstance(icm._credentials, Credentials)
228+
mock_auth.assert_called_once()

0 commit comments

Comments
 (0)