diff --git a/README.md b/README.md index 1f0e633b9..d79e706d3 100644 --- a/README.md +++ b/README.md @@ -428,6 +428,44 @@ with Connector(resolver=DnsResolver) as connector: # ... use SQLAlchemy engine normally ``` +### Automatic failover using DNS domain names + +> [!NOTE] +> +> Usage of the `asyncpg` driver does not currently support automatic failover. + +When the connector is configured using a domain name, the connector will +periodically check if the DNS record for an instance changes. When the connector +detects that the domain name refers to a different instance, the connector will +close all open connections to the old instance. Subsequent connection attempts +will be directed to the new instance. + +For example: suppose application is configured to connect using the +domain name `prod-db.mycompany.example.com`. Initially the private DNS +zone has a TXT record with the value `my-project:region:my-instance`. The +application establishes connections to the `my-project:region:my-instance` +Cloud SQL instance. + +Then, to reconfigure the application to use a different database +instance, change the value of the `prod-db.mycompany.example.com` DNS record +from `my-project:region:my-instance` to `my-project:other-region:my-instance-2` + +The connector inside the application detects the change to this +DNS record. Now, when the application connects to its database using the +domain name `prod-db.mycompany.example.com`, it will connect to the +`my-project:other-region:my-instance-2` Cloud SQL instance. + +The connector will automatically close all existing connections to +`my-project:region:my-instance`. This will force the connection pools to +establish new connections. Also, it may cause database queries in progress +to fail. + +The connector will poll for changes to the DNS name every 30 seconds by default. +You may configure the frequency of the connections using the Connector's +`failover_period` argument (i.e. `Connector(failover_period=60`). When this is +set to 0, the connector will disable polling and only check if the DNS record +changed when it is creating a new connection. + ### Using the Python Connector with Python Web Frameworks The Python Connector can be used alongside popular Python web frameworks such diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 99a5097a2..6913337d3 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 82e3a9018..c9e48935f 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import abc from dataclasses import dataclass import logging import ssl @@ -34,6 +35,27 @@ logger = logging.getLogger(name=__name__) +class ConnectionInfoCache(abc.ABC): + """Abstract class for Connector connection info caches.""" + + @abc.abstractmethod + async def connect_info(self) -> ConnectionInfo: + pass + + @abc.abstractmethod + async def force_refresh(self) -> None: + pass + + @abc.abstractmethod + async def close(self) -> None: + pass + + @property + @abc.abstractmethod + def closed(self) -> bool: + pass + + @dataclass class ConnectionInfo: """Contains all necessary information to connect securely to the diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index 437fd6607..ad5dc40fb 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -42,6 +42,10 @@ def __str__(self) -> str: return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" return f"{self.project}:{self.region}:{self.instance_name}" + def get_connection_string(self) -> str: + """Get the instance connection string for the Cloud SQL instance.""" + return f"{self.project}:{self.region}:{self.instance_name}" + def _is_valid_domain(domain_name: str) -> bool: if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None: diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 3e53e754a..c76092a40 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,9 +20,10 @@ from functools import partial import logging import os +import socket from threading import Thread from types import TracebackType -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import google.auth from google.auth.credentials import Credentials @@ -35,6 +36,7 @@ from google.cloud.sql.connector.enums import RefreshStrategy from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -46,6 +48,7 @@ logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] +SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" _SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}" @@ -67,6 +70,7 @@ def __init__( universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, + failover_period: int = 30, ) -> None: """Initializes a Connector instance. @@ -114,6 +118,11 @@ def __init__( name. To resolve a DNS record to an instance connection name, use DnsResolver. Default: DefaultResolver + + failover_period (int): The time interval in seconds between each + attempt to check if a failover has occured for a given instance. + Must be used with `resolver=DnsResolver` to have any effect. + Default: 30 """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -143,9 +152,7 @@ def __init__( ) # initialize dict to store caches, key is a tuple consisting of instance # connection name string and enable_iam_auth boolean flag - self._cache: dict[ - tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache] - ] = {} + self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None # initialize credentials @@ -167,6 +174,7 @@ def __init__( self._enable_iam_auth = enable_iam_auth self._user_agent = user_agent self._resolver = resolver() + self._failover_period = failover_period # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) @@ -285,15 +293,19 @@ async def connect_async( driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - if (instance_connection_string, enable_iam_auth) in self._cache: - cache = self._cache[(instance_connection_string, enable_iam_auth)] + + conn_name = await self._resolver.resolve(instance_connection_string) + # Cache entry must exist and not be closed + if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ + (str(conn_name), enable_iam_auth) + ].closed: + monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] else: - conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) - cache = LazyRefreshCache( + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( conn_name, self._client, self._keys, @@ -309,8 +321,14 @@ async def connect_async( self._keys, enable_iam_auth, ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(instance_connection_string, enable_iam_auth)] = cache + self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache connect_func = { "pymysql": pymysql.connect, @@ -321,7 +339,7 @@ async def connect_async( # only accept supported database drivers try: - connector = connect_func[driver] + connector: Callable = connect_func[driver] # type: ignore except KeyError: raise KeyError(f"Driver '{driver}' is not supported.") @@ -339,14 +357,14 @@ async def connect_async( # attempt to get connection info for Cloud SQL instance try: - conn_info = await cache.connect_info() + conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) ip_address = conn_info.get_preferred_ip(ip_type) except Exception: # with an error from Cloud SQL Admin API call or IP type, invalidate # the cache and re-raise the error - await self._remove_cached(instance_connection_string, enable_iam_auth) + await self._remove_cached(str(conn_name), enable_iam_auth) raise logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn @@ -367,18 +385,28 @@ async def connect_async( await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) - # synchronous drivers are blocking and run using executor + # Create socket with SSLContext for sync drivers + ctx = await conn_info.create_ssl_context(enable_iam_auth) + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) + # If this connection was opened using a domain name, then store it + # for later in case we need to forcibly close it on failover. + if conn_info.conn_name.domain_name: + monitored_cache.sockets.append(sock) + # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, - await conn_info.create_ssl_context(enable_iam_auth), + sock, **kwargs, ) return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error - await cache.force_refresh() + await monitored_cache.force_refresh() raise async def _remove_cached( @@ -456,6 +484,7 @@ async def create_async_connector( universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, + failover_period: int = 30, ) -> Connector: """Helper function to create Connector object for asyncio connections. @@ -507,6 +536,11 @@ async def create_async_connector( DnsResolver. Default: DefaultResolver + failover_period (int): The time interval in seconds between each + attempt to check if a failover has occured for a given instance. + Must be used with `resolver=DnsResolver` to have any effect. + Default: 30 + Returns: A Connector instance configured with running event loop. """ @@ -525,4 +559,5 @@ async def create_async_connector( universe_domain=universe_domain, refresh_strategy=refresh_strategy, resolver=resolver, + failover_period=failover_period, ) diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 92e3e5662..da39ea25d 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -77,3 +77,10 @@ class DnsResolutionError(Exception): Exception to be raised when an instance connection name can not be resolved from a DNS record. """ + + +class CacheClosedError(Exception): + """ + Exception to be raised when a ConnectionInfoCache can not be accessed after + it is closed. + """ diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 5df272fe2..fb8711309 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -24,6 +24,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter @@ -35,7 +36,7 @@ APPLICATION_NAME = "cloud-sql-python-connector" -class RefreshAheadCache: +class RefreshAheadCache(ConnectionInfoCache): """Cache that refreshes connection info in the background prior to expiration. Background tasks are used to schedule refresh attempts to get a new @@ -74,6 +75,15 @@ def __init__( self._refresh_in_progress = asyncio.locks.Event() self._current: asyncio.Task = self._schedule_refresh(0) self._next: asyncio.Task = self._current + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed async def force_refresh(self) -> None: """ @@ -212,3 +222,4 @@ async def close(self) -> None: # gracefully wait for tasks to cancel tasks = asyncio.gather(self._current, self._next, return_exceptions=True) await asyncio.wait_for(tasks, timeout=2.0) + self._closed = True diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 1bc4f90f8..c75d07e52 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,13 +21,14 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) -class LazyRefreshCache: +class LazyRefreshCache(ConnectionInfoCache): """Cache that refreshes connection info when a caller requests a connection. Only refreshes the cache when a new connection is requested and the current @@ -62,6 +63,15 @@ def __init__( self._lock = asyncio.Lock() self._cached: Optional[ConnectionInfo] = None self._needs_refresh = False + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed async def force_refresh(self) -> None: """ @@ -121,4 +131,5 @@ async def close(self) -> None: """Close is a no-op and provided purely for a consistent interface with other cache types. """ - pass + self._closed = True + return diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py new file mode 100644 index 000000000..0c3fc4d03 --- /dev/null +++ b/google/cloud/sql/connector/monitored_cache.py @@ -0,0 +1,146 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import ssl +from typing import Any, Callable, Optional, Union + +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache +from google.cloud.sql.connector.exceptions import CacheClosedError +from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +logger = logging.getLogger(name=__name__) + + +class MonitoredCache(ConnectionInfoCache): + def __init__( + self, + cache: Union[RefreshAheadCache, LazyRefreshCache], + failover_period: int, + resolver: Union[DefaultResolver, DnsResolver], + ) -> None: + self.resolver = resolver + self.cache = cache + self.domain_name_ticker: Optional[asyncio.Task] = None + self.sockets: list[ssl.SSLSocket] = [] + + # If domain name is configured for instance and failover period is set, + # poll for DNS record changes. + if self.cache.conn_name.domain_name and failover_period > 0: + self.domain_name_ticker = asyncio.create_task( + ticker(failover_period, self._check_domain_name) + ) + logger.debug( + f"['{self.cache.conn_name}']: Configured polling of domain " + f"name with failover period of {failover_period} seconds." + ) + + @property + def closed(self) -> bool: + return self.cache.closed + + def _purge_closed_sockets(self) -> None: + """Remove closed sockets from monitored cache. + + If a socket is closed by the database driver we should remove it from + list of sockets. + """ + open_sockets = [] + for socket in self.sockets: + # Check fileno for if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + open_sockets.append(socket) + self.sockets = open_sockets + + async def _check_domain_name(self) -> None: + # remove any closed connections from cache + self._purge_closed_sockets() + try: + # Resolve domain name and see if Cloud SQL instance connection name + # has changed. If it has, close all connections. + new_conn_name = await self.resolver.resolve( + self.cache.conn_name.domain_name + ) + if new_conn_name != self.cache.conn_name: + logger.debug( + f"['{self.cache.conn_name}']: Cloud SQL instance changed " + f"from {self.cache.conn_name.get_connection_string()} to " + f"{new_conn_name.get_connection_string()}, closing all " + "connections!" + ) + await self.close() + + except Exception as e: + # Domain name checks should not be fatal, log error and continue. + logger.debug( + f"['{self.cache.conn_name}']: Unable to check domain name, " + f"domain name {self.cache.conn_name.domain_name} did not " + f"resolve: {e}" + ) + + async def connect_info(self) -> ConnectionInfo: + if self.closed: + raise CacheClosedError( + "Can not get connection info, cache has already been closed." + ) + return await self.cache.connect_info() + + async def force_refresh(self) -> None: + # if cache is closed do not refresh + if self.closed: + return + return await self.cache.force_refresh() + + async def close(self) -> None: + # Cancel domain name ticker task. + if self.domain_name_ticker: + self.domain_name_ticker.cancel() + try: + await self.domain_name_ticker + except asyncio.CancelledError: + logger.debug( + f"['{self.cache.conn_name}']: Cancelled domain name polling task." + ) + finally: + self.domain_name_ticker = None + # If cache is already closed, no further work. + if self.closed: + return + + # Close underyling ConnectionInfoCache + await self.cache.close() + + # Close any still open sockets + for socket in self.sockets: + # Check fileno for if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + socket.close() + + +async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None: + """ + Ticker function to sleep for specified interval and then schedule call + to given function. + """ + while True: + # Sleep for interval and then schedule task + await asyncio.sleep(interval) + asyncio.create_task(function(*args, **kwargs)) diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 1f66dde2a..baaee6615 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -14,18 +14,15 @@ limitations under the License. """ -import socket import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pg8000 def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pg8000.dbapi.Connection": """Helper function to create a pg8000 DB-API connection object. @@ -33,8 +30,8 @@ def connect( :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. @@ -48,12 +45,6 @@ def connect( 'Unable to import module "pg8000." Please install and try again.' ) - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index a16584367..f83f7076c 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -14,18 +14,15 @@ limitations under the License. """ -import socket import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pymysql def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pymysql.connections.Connection": """Helper function to create a pymysql DB-API connection object. @@ -33,8 +30,8 @@ def connect( :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. :rtype: pymysql.Connection @@ -50,11 +47,6 @@ def connect( # allow automatic IAM database authentication to not require password kwargs["password"] = kwargs["password"] if "password" in kwargs else None - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) # pop timeout as timeout arg is called 'connect_timeout' for pymysql timeout = kwargs.pop("timeout") kwargs["connect_timeout"] = kwargs.get("connect_timeout", timeout) diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 243d90fd5..3128fdb6a 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -15,27 +15,24 @@ """ import platform -import socket import ssl from typing import Any, TYPE_CHECKING from google.cloud.sql.connector.exceptions import PlatformNotSupportedError -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pytds -def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Connection": +def connect(ip_address: str, sock: ssl.SSLSocket, **kwargs: Any) -> "pytds.Connection": """Helper function to create a pytds DB-API connection object. :type ip_address: str :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. @@ -51,11 +48,6 @@ def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Conne db = kwargs.pop("db", None) - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) if kwargs.pop("active_directory_auth", False): if platform.system() == "Windows": # Ignore username and password if using active directory auth diff --git a/tests/conftest.py b/tests/conftest.py index 3a1a38a27..c75de48cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index c2b5cf125..258b80aaf 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_ip_types.py b/tests/system/test_ip_types.py index 2df3b1df5..3af49c54f 100644 --- a/tests/system/test_ip_types.py +++ b/tests/system/test_ip_types.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_pymysql_connection.py b/tests/system/test_pymysql_connection.py index 490b1fab4..1e7e26830 100644 --- a/tests/system/test_pymysql_connection.py +++ b/tests/system/test_pymysql_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_pytds_connection.py b/tests/system/test_pytds_connection.py index d848abc18..fd88d230f 100644 --- a/tests/system/test_pytds_connection.py +++ b/tests/system/test_pytds_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index 218034d51..0861d8245 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -31,6 +31,8 @@ def test_ConnectionName() -> None: assert conn_name.domain_name == "" # test ConnectionName str() method prints instance connection name assert str(conn_name) == "project:region:instance" + # test ConnectionName.get_connection_string + assert conn_name.get_connection_string() == "project:region:instance" def test_ConnectionName_with_domain_name() -> None: @@ -42,6 +44,8 @@ def test_ConnectionName_with_domain_name() -> None: assert conn_name.domain_name == "db.example.com" # test ConnectionName str() method prints with domain name assert str(conn_name) == "db.example.com -> project:region:instance" + # test ConnectionName.get_connection_string + assert conn_name.get_connection_string() == "project:region:instance" @pytest.mark.parametrize( diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index e25c9a384..498c947cc 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index aeedf3399..1a3d60917 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 344b073e8..c6eef7509 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -21,6 +21,27 @@ from google.cloud.sql.connector.utils import generate_keys +async def test_LazyRefreshCache_properties(fake_client: CloudSQLClient) -> None: + """ + Test that LazyRefreshCache properties work as expected. + """ + keys = asyncio.create_task(generate_keys()) + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=keys, + enable_iam_auth=False, + ) + # test conn_name property + assert cache.conn_name == conn_name + # test closed property + assert cache.closed is False + # close cache and make sure property is updated + await cache.close() + assert cache.closed is True + + async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> None: """ Test that LazyRefreshCache.connect_info works as expected. diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py new file mode 100644 index 000000000..1eea4eb46 --- /dev/null +++ b/tests/unit/test_monitored_cache.py @@ -0,0 +1,240 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import socket + +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch +from mocks import create_ssl_context +import pytest + +from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import CacheClosedError +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.monitored_cache import MonitoredCache +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.utils import generate_keys + +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_MonitoredCache_properties(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache properties work as expected. + """ + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 30, DefaultResolver()) + # test that ticker is not set for instance not using domain name + assert monitored_cache.domain_name_ticker is None + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + + +async def test_MonitoredCache_CacheClosedError(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache.connect_info errors when cache is closed. + """ + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 30, DefaultResolver()) + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + # attempt to get connect info from closed cache + with pytest.raises(CacheClosedError): + await monitored_cache.connect_info() + + +async def test_MonitoredCache_with_DnsResolver(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache with DnsResolver work as expected. + """ + conn_name = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + monitored_cache = MonitoredCache(cache, 30, resolver) + # test that ticker is set for instance using domain name + assert type(monitored_cache.domain_name_ticker) is asyncio.Task + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + # domain name ticker should be set back to None + assert monitored_cache.domain_name_ticker is None + + +async def test_MonitoredCache_with_disabled_failover( + fake_client: CloudSQLClient, +) -> None: + """ + Test that MonitoredCache disables DNS polling with failover_period=0 + """ + conn_name = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 0, DnsResolver()) + # test that ticker is not set when failover is disabled + assert monitored_cache.domain_name_ticker is None + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + + +@pytest.mark.usefixtures("server") +async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache is closed when _check_domain_name has domain change. + """ + conn_name = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + + # configure a local socket + ip_addr = "127.0.0.1" + context = await create_ssl_context() + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, + ) + # verify socket is open + assert sock.fileno() != -1 + # set failover to 0 to disable polling + monitored_cache = MonitoredCache(cache, 0, resolver) + # add socket to cache + monitored_cache.sockets = [sock] + # check cache is not closed + assert monitored_cache.closed is False + # call _check_domain_name and verify cache is closed + await monitored_cache._check_domain_name() + assert monitored_cache.closed is True + # verify socket was closed + assert sock.fileno() == -1 + + +@pytest.mark.usefixtures("server") +async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache._purge_closed_sockets removes closed sockets from + cache. + """ + conn_name = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # configure a local socket + ip_addr = "127.0.0.1" + context = await create_ssl_context() + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, + ) + + # set failover to 0 to disable polling + monitored_cache = MonitoredCache(cache, 0, DnsResolver()) + # verify socket is open + assert sock.fileno() != -1 + # add socket to cache + monitored_cache.sockets = [sock] + # call _purge_closed_sockets and verify socket remains + monitored_cache._purge_closed_sockets() + # verify socket is still open + assert sock.fileno() != -1 + assert len(monitored_cache.sockets) == 1 + # close socket + sock.close() + # call _purge_closed_sockets and verify socket is removed + monitored_cache._purge_closed_sockets() + assert len(monitored_cache.sockets) == 0 diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py index 1b2adbb65..e01a53445 100644 --- a/tests/unit/test_pg8000.py +++ b/tests/unit/test_pg8000.py @@ -14,7 +14,7 @@ limitations under the License. """ -from functools import partial +import socket from typing import Any from mock import patch @@ -31,15 +31,14 @@ async def test_pg8000(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) with patch("pg8000.dbapi.connect") as mock_connect: mock_connect.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) assert connection is True # verify that driver connection call would be made assert mock_connect.assert_called_once diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py index 69d2aba8f..66b1f22a3 100644 --- a/tests/unit/test_pymysql.py +++ b/tests/unit/test_pymysql.py @@ -14,7 +14,7 @@ limitations under the License. """ -from functools import partial +import socket import ssl from typing import Any @@ -40,15 +40,14 @@ async def test_pymysql(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) kwargs["timeout"] = 30 with patch("pymysql.Connection") as mock_connect: mock_connect.return_value = MockConnection - pymysql_connect(ip_addr, context, **kwargs) + pymysql_connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert mock_connect.assert_called_once diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py index 633aab74a..9efe00ee5 100644 --- a/tests/unit/test_pytds.py +++ b/tests/unit/test_pytds.py @@ -14,8 +14,8 @@ limitations under the License. """ -from functools import partial import platform +import socket from typing import Any from mock import patch @@ -43,16 +43,15 @@ async def test_pytds(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) with patch("pytds.connect") as mock_connect: mock_connect.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert connection is True assert mock_connect.assert_called_once @@ -68,17 +67,16 @@ async def test_pytds_platform_error(kwargs: Any) -> None: assert platform.system() == "Linux" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) # add active_directory_auth to kwargs kwargs["active_directory_auth"] = True # verify that error is thrown with Linux and active_directory_auth with pytest.raises(PlatformNotSupportedError): - connect(ip_addr, context, **kwargs) + connect(ip_addr, sock, **kwargs) @pytest.mark.usefixtures("server") @@ -94,11 +92,10 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: assert platform.system() == "Windows" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) # add active_directory_auth and server_name to kwargs kwargs["active_directory_auth"] = True @@ -107,7 +104,7 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: mock_connect.return_value = True with patch("pytds.login.SspiAuth") as mock_login: mock_login.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert mock_login.assert_called_once assert connection is True diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py index 5e187b81d..8ef586b58 100644 --- a/tests/unit/test_rate_limiter.py +++ b/tests/unit/test_rate_limiter.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 6545bc7a8..fe4e90955 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License");