diff --git a/google/cloud/alloydbconnector/async_connector.py b/google/cloud/alloydbconnector/async_connector.py index 7aac297e..0d6da30a 100644 --- a/google/cloud/alloydbconnector/async_connector.py +++ b/google/cloud/alloydbconnector/async_connector.py @@ -147,6 +147,19 @@ async def connect( enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + # callable to be used for auto IAM authn + def get_authentication_token() -> str: + """Get OAuth2 access token to be used for IAM database authentication""" + # refresh credentials if expired + if not self._credentials.valid: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + return self._credentials.token + + # if enable_iam_auth is set, use auth token as database password + if enable_iam_auth: + kwargs["password"] = get_authentication_token + # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] @@ -185,36 +198,27 @@ async def connect( # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes(ip_type.upper()) - try: - conn_info = await cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from AlloyDB API call or IP type, invalidate the - # cache and re-raise the error - await self._remove_cached(instance_uri) - raise - logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") - - # callable to be used for auto IAM authn - def get_authentication_token() -> str: - """Get OAuth2 access token to be used for IAM database authentication""" - # refresh credentials if expired - if not self._credentials.valid: - request = google.auth.transport.requests.Request() - self._credentials.refresh(request) - return self._credentials.token - - # if enable_iam_auth is set, use auth token as database password - if enable_iam_auth: - kwargs["password"] = get_authentication_token - try: - return await connector( - ip_address, await conn_info.create_ssl_context(), **kwargs - ) - except Exception: - # we attempt a force refresh, then throw the error - await cache.force_refresh() - raise + for i in range(2): + try: + conn_info = await cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from AlloyDB API call or IP type, invalidate the + # cache and re-raise the error + await self._remove_cached(instance_uri) + raise + logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") + + try: + return await connector( + ip_address, await conn_info.create_ssl_context(), **kwargs + ) + except Exception: + # we attempt a force refresh and retry the connection before + # throwing an error + if i == 1: + raise + await cache.force_refresh(True) async def _remove_cached(self, instance_uri: str) -> None: """Stops all background refreshes and deletes the connection diff --git a/google/cloud/alloydbconnector/connector.py b/google/cloud/alloydbconnector/connector.py index 0b100a82..b6c0eb84 100644 --- a/google/cloud/alloydbconnector/connector.py +++ b/google/cloud/alloydbconnector/connector.py @@ -226,31 +226,34 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes(ip_type.upper()) - try: - conn_info = await cache.connect_info() - ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from AlloyDB API call or IP type, invalidate the - # cache and re-raise the error - await self._remove_cached(instance_uri) - raise - logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") - - # synchronous drivers are blocking and run using executor - try: - metadata_partial = partial( - self.metadata_exchange, - ip_address, - await conn_info.create_ssl_context(), - enable_iam_auth, - ) - sock = await self._loop.run_in_executor(None, metadata_partial) - connect_partial = partial(connector, sock, **kwargs) - return await self._loop.run_in_executor(None, connect_partial) - except Exception: - # we attempt a force refresh, then throw the error - await cache.force_refresh() - raise + for i in range(2): + try: + conn_info = await cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from AlloyDB API call or IP type, invalidate the + # cache and re-raise the error + await self._remove_cached(instance_uri) + raise + logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433") + + # synchronous drivers are blocking and run using executor + try: + metadata_partial = partial( + self.metadata_exchange, + ip_address, + await conn_info.create_ssl_context(), + enable_iam_auth, + ) + sock = await self._loop.run_in_executor(None, metadata_partial) + connect_partial = partial(connector, sock, **kwargs) + return await self._loop.run_in_executor(None, connect_partial) + except Exception: + # we attempt a force refresh and retry the connection before + # throwing an error + if i == 1: + raise + await cache.force_refresh(True) def metadata_exchange( self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool diff --git a/google/cloud/alloydbconnector/instance.py b/google/cloud/alloydbconnector/instance.py index 9c3051e2..b6626233 100644 --- a/google/cloud/alloydbconnector/instance.py +++ b/google/cloud/alloydbconnector/instance.py @@ -208,7 +208,7 @@ async def _refresh_operation(self, delay: int) -> ConnectionInfo: return refresh_result - async def force_refresh(self) -> None: + async def force_refresh(self, block: bool = False) -> None: """ Schedules a new refresh operation immediately to be used for future connection attempts. @@ -218,7 +218,7 @@ async def force_refresh(self) -> None: self._next.cancel() self._next = self._schedule_refresh(0) # block all sequential connection attempts on the next refresh result if current is invalid - if not await _is_valid(self._current): + if block or not await _is_valid(self._current): self._current = self._next async def connect_info(self) -> ConnectionInfo: diff --git a/google/cloud/alloydbconnector/lazy.py b/google/cloud/alloydbconnector/lazy.py index d86ea522..5bce7ca5 100644 --- a/google/cloud/alloydbconnector/lazy.py +++ b/google/cloud/alloydbconnector/lazy.py @@ -63,7 +63,7 @@ def __init__( self._cached: Optional[ConnectionInfo] = None self._needs_refresh = False - async def force_refresh(self) -> None: + async def force_refresh(self, block: bool = False) -> None: """ Invalidates the cache and configures the next call to connect_info() to retrieve a fresh ConnectionInfo instance. diff --git a/google/cloud/alloydbconnector/static.py b/google/cloud/alloydbconnector/static.py index 18555a15..1eeb2bc3 100644 --- a/google/cloud/alloydbconnector/static.py +++ b/google/cloud/alloydbconnector/static.py @@ -81,7 +81,7 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration ) - async def force_refresh(self) -> None: + async def force_refresh(self, block: bool = False) -> None: """ This is a no-op as the cache holds only static connection information and does no refresh. diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 2dfbd9a4..d17bcd65 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -386,6 +386,7 @@ class FakeConnectionInfo: def __init__(self) -> None: self._close_called = False self._force_refresh_called = False + self._force_refresh_blocking = False def connect_info(self) -> Any: f = asyncio.Future() @@ -400,8 +401,9 @@ def get_preferred_ip(self, ip_type: Any) -> tuple[str, Any]: async def create_ssl_context(self) -> None: return None - async def force_refresh(self) -> None: + async def force_refresh(self, block: bool) -> None: self._force_refresh_called = True + self._force_refresh_blocking = block async def close(self) -> None: self._close_called = True diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index fb5076bd..1c21b36c 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -397,3 +397,30 @@ async def test_connect_when_closed(credentials: FakeCredentials) -> None: exc_info.value.args[0] == "Connection attempt failed because the connector has already been closed." ) + + +async def test_connect_after_force_refresh( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: + """ + Test that connector.connect can succeed after force refreshing its cache. + """ + async with AsyncConnector(credentials) as connector: + fake = FakeConnectionInfo() + connector._cache[TEST_INSTANCE_NAME] = fake + + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydbconnector.asyncpg.connect") as mock_connect: + mock_connect.side_effect = [Exception(), True] + connection = await connector.connect( + TEST_INSTANCE_NAME, + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection after force refreshing cache + assert connection is True + assert fake._force_refresh_called is True + assert fake._force_refresh_blocking is True diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 349120c6..e95e1a2e 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -337,3 +337,26 @@ def test_connect_when_closed(credentials: FakeCredentials) -> None: exc_info.value.args[0] == "Connection attempt failed because the connector has already been closed." ) + + +@pytest.mark.usefixtures("proxy_server") +def test_connect_after_force_refresh( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: + """ + Test that connector.connect can succeed after force refreshing its cache. + """ + with Connector(credentials) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydbconnector.pg8000.connect") as mock_connect: + mock_connect.side_effect = [Exception(), True] + connection = connector.connect( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + "pg8000", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection is returned after force refreshing cache + assert connection is True diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index b643d836..d976fcf6 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -246,3 +246,25 @@ async def test_force_refresh_cancels_pending_refresh() -> None: assert isinstance(await cache._current, ConnectionInfo) # close instance await cache.close() + + +@pytest.mark.asyncio +async def test_force_refresh_after_blocking_sets_current_to_next() -> None: + """ + Test that force_refresh sets the current task to the next task. + """ + keys = asyncio.create_task(generate_keys()) + client = FakeAlloyDBClient() + cache = RefreshAheadCache( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + client, + keys, + ) + # make sure initial refresh is finished + await cache._current + + assert cache._current != cache._next + await cache.force_refresh(True) + assert cache._current == cache._next + # close instance + await cache.close()