diff --git a/.cl/cloudbuild.yaml b/.cl/cloudbuild.yaml index 84d367122..7fa63a487 100644 --- a/.cl/cloudbuild.yaml +++ b/.cl/cloudbuild.yaml @@ -13,16 +13,76 @@ # limitations under the License. steps: - - id: ping-google - name: alpine:3.10 - entrypoint: ping + - id: run integration tests ${_EARLIEST_PYTHON_VERSION} + name: python:${_EARLIEST_PYTHON_VERSION} + entrypoint: bash + env: + - "IP_TYPE=private" + secretEnv: ["MYSQL_CONNECTION_NAME", "MYSQL_USER", "MYSQL_IAM_USER", "MYSQL_PASS", "MYSQL_DB", "POSTGRES_CONNECTION_NAME", "POSTGRES_USER", "POSTGRES_IAM_USER", "POSTGRES_PASS", "POSTGRES_DB", "POSTGRES_CAS_CONNECTION_NAME", "POSTGRES_CAS_PASS", "POSTGRES_CUSTOMER_CAS_CONNECTION_NAME", "POSTGRES_CUSTOMER_CAS_PASS", "POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME","SQLSERVER_CONNECTION_NAME", "SQLSERVER_USER", "SQLSERVER_PASS", "SQLSERVER_DB"] args: - - -c - - "4" - - google.com - + - "-c" + - | + pip install nox + nox -s system-${_EARLIEST_PYTHON_VERSION} + waitFor: ["-"] + - id: run integration tests ${_LATEST_PYTHON_VERSION} + name: python:${_LATEST_PYTHON_VERSION} + entrypoint: bash + env: + - "IP_TYPE=private" + secretEnv: ["MYSQL_CONNECTION_NAME", "MYSQL_USER", "MYSQL_IAM_USER", "MYSQL_PASS", "MYSQL_DB", "POSTGRES_CONNECTION_NAME", "POSTGRES_USER", "POSTGRES_IAM_USER", "POSTGRES_PASS", "POSTGRES_DB", "POSTGRES_CAS_CONNECTION_NAME", "POSTGRES_CAS_PASS", "POSTGRES_CUSTOMER_CAS_CONNECTION_NAME", "POSTGRES_CUSTOMER_CAS_PASS", "POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME","SQLSERVER_CONNECTION_NAME", "SQLSERVER_USER", "SQLSERVER_PASS", "SQLSERVER_DB"] + args: + - "-c" + - | + pip install nox + nox -s system-${_LATEST_PYTHON_VERSION} + waitFor: ["-"] +availableSecrets: + secretManager: + - versionName: 'projects/$PROJECT_ID/secrets/MYSQL_CONNECTION_NAME/versions/latest' + env: 'MYSQL_CONNECTION_NAME' + - versionName: 'projects/$PROJECT_ID/secrets/MYSQL_USER/versions/latest' + env: 'MYSQL_USER' + - versionName: 'projects/$PROJECT_ID/secrets/CLOUD_BUILD_MYSQL_IAM_USER/versions/latest' + env: 'MYSQL_IAM_USER' + - versionName: 'projects/$PROJECT_ID/secrets/MYSQL_PASS/versions/latest' + env: 'MYSQL_PASS' + - versionName: 'projects/$PROJECT_ID/secrets/MYSQL_DB/versions/latest' + env: 'MYSQL_DB' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_CONNECTION_NAME/versions/latest' + env: 'POSTGRES_CONNECTION_NAME' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_USER/versions/latest' + env: 'POSTGRES_USER' + - versionName: 'projects/$PROJECT_ID/secrets/CLOUD_BUILD_POSTGRES_IAM_USER/versions/latest' + env: 'POSTGRES_IAM_USER' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_PASS/versions/latest' + env: 'POSTGRES_PASS' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_DB/versions/latest' + env: 'POSTGRES_DB' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_CAS_CONNECTION_NAME/versions/latest' + env: 'POSTGRES_CAS_CONNECTION_NAME' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_CAS_PASS/versions/latest' + env: 'POSTGRES_CAS_PASS' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME/versions/latest' + env: 'POSTGRES_CUSTOMER_CAS_CONNECTION_NAME' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_CUSTOMER_CAS_PASS/versions/latest' + env: 'POSTGRES_CUSTOMER_CAS_PASS' + - versionName: 'projects/$PROJECT_ID/secrets/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME/versions/latest' + env: 'POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME' + - versionName: 'projects/$PROJECT_ID/secrets/SQLSERVER_CONNECTION_NAME/versions/latest' + env: 'SQLSERVER_CONNECTION_NAME' + - versionName: 'projects/$PROJECT_ID/secrets/SQLSERVER_USER/versions/latest' + env: 'SQLSERVER_USER' + - versionName: 'projects/$PROJECT_ID/secrets/SQLSERVER_PASS/versions/latest' + env: 'SQLSERVER_PASS' + - versionName: 'projects/$PROJECT_ID/secrets/SQLSERVER_DB/versions/latest' + env: 'SQLSERVER_DB' +substitutions: + _LATEST_PYTHON_VERSION: '3.13' + _EARLIEST_PYTHON_VERSION: '3.9' + options: dynamicSubstitutions: true pool: - name: $_POOL_NAME - logging: CLOUD_LOGGING_ONLY \ No newline at end of file + name: ${_POOL_NAME} + logging: CLOUD_LOGGING_ONLY diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml index ded14a19b..bc2a8baf0 100644 --- a/.github/blunderbuss.yml +++ b/.github/blunderbuss.yml @@ -14,6 +14,7 @@ assign_issues: - jackwotherspoon + - kgala2 assign_prs: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e710138f6..b8e6eb58d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -81,6 +81,7 @@ jobs: POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -102,6 +103,7 @@ jobs: POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}" SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" diff --git a/.gitignore b/.gitignore index 9ef6a9067..9f449ce4a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ venv .python-version cloud_sql_python_connector.egg-info/ dist/ +.idea +.coverage +sponge_log.xml diff --git a/CHANGELOG.md b/CHANGELOG.md index c812c21bf..47c853bca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [1.18.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.17.0...v1.18.0) (2025-03-21) + + +### Features + +* add domain name validation ([#1246](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1246)) ([15934bd](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/15934bd18ab426edd19af67be799876b52895a48)) +* reset connection when the DNS record changes ([#1241](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1241)) ([1405f56](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/1405f564019f6a30a15535ed2e0d1dc108f38195)) + ## [1.17.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.16.0...v1.17.0) (2025-02-12) diff --git a/README.md b/README.md index 1f0e633b9..1c5489e04 100644 --- a/README.md +++ b/README.md @@ -126,21 +126,16 @@ import sqlalchemy # initialize Connector object connector = Connector() -# function to return the database connection -def getconn() -> pymysql.connections.Connection: - conn: pymysql.connections.Connection = connector.connect( +# initialize SQLAlchemy connection pool with Connector +pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( "project:region:instance", "pymysql", user="my-user", password="my-password", db="my-db-name" - ) - return conn - -# create connection pool -pool = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) ``` @@ -207,33 +202,21 @@ Connector as a context manager: ```python from google.cloud.sql.connector import Connector -import pymysql import sqlalchemy -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( +# initialize Cloud SQL Python Connector as context manager +with Connector() as connector: + # initialize SQLAlchemy connection pool with Connector + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( "project:region:instance", "pymysql", user="my-user", password="my-password", db="my-db-name" - ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) - return pool - -# initialize Cloud SQL Python Connector as context manager -with Connector() as connector: - # initialize connection pool - pool = init_connection_pool(connector) # insert statement insert_stmt = sqlalchemy.text( "INSERT INTO my_table (id, title) VALUES (:id, :title)", @@ -401,33 +384,60 @@ from google.cloud.sql.connector import Connector, DnsResolver import pymysql import sqlalchemy -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( +# initialize Cloud SQL Python Connector with `resolver=DnsResolver` +with Connector(resolver=DnsResolver) as connector: + # initialize SQLAlchemy connection pool with Connector + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( "prod-db.mycompany.example.com", # using DNS name "pymysql", user="my-user", password="my-password", db="my-db-name" - ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) - return pool - -# initialize Cloud SQL Python Connector with `resolver=DnsResolver` -with Connector(resolver=DnsResolver) as connector: - # initialize connection pool - pool = init_connection_pool(connector) # ... use SQLAlchemy engine normally ``` +### Automatic failover using DNS domain names + +> [!NOTE] +> +> Usage of the `asyncpg` driver does not currently support automatic failover. + +When the connector is configured using a domain name, the connector will +periodically check if the DNS record for an instance changes. When the connector +detects that the domain name refers to a different instance, the connector will +close all open connections to the old instance. Subsequent connection attempts +will be directed to the new instance. + +For example: suppose application is configured to connect using the +domain name `prod-db.mycompany.example.com`. Initially the private DNS +zone has a TXT record with the value `my-project:region:my-instance`. The +application establishes connections to the `my-project:region:my-instance` +Cloud SQL instance. + +Then, to reconfigure the application to use a different database +instance, change the value of the `prod-db.mycompany.example.com` DNS record +from `my-project:region:my-instance` to `my-project:other-region:my-instance-2` + +The connector inside the application detects the change to this +DNS record. Now, when the application connects to its database using the +domain name `prod-db.mycompany.example.com`, it will connect to the +`my-project:other-region:my-instance-2` Cloud SQL instance. + +The connector will automatically close all existing connections to +`my-project:region:my-instance`. This will force the connection pools to +establish new connections. Also, it may cause database queries in progress +to fail. + +The connector will poll for changes to the DNS name every 30 seconds by default. +You may configure the frequency of the connections using the Connector's +`failover_period` argument (i.e. `Connector(failover_period=60`). When this is +set to 0, the connector will disable polling and only check if the DNS record +changed when it is creating a new connection. + ### Using the Python Connector with Python Web Frameworks The Python Connector can be used alongside popular Python web frameworks such @@ -463,9 +473,12 @@ from google.cloud.sql.connector import Connector # initialize Python Connector object connector = Connector() -# Python Connector database connection function -def getconn(): - conn = connector.connect( +app = Flask(__name__) + +# configure Flask-SQLAlchemy to use Python Connector +app.config['SQLALCHEMY_DATABASE_URI'] = "postgresql+pg8000://" +app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { + "creator": lambda: conn = connector.connect( "project:region:instance-name", # Cloud SQL Instance Connection Name "pg8000", user="my-user", @@ -473,15 +486,6 @@ def getconn(): db="my-database", ip_type="public" # "private" for private IP ) - return conn - - -app = Flask(__name__) - -# configure Flask-SQLAlchemy to use Python Connector -app.config['SQLALCHEMY_DATABASE_URI'] = "postgresql+pg8000://" -app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { - "creator": getconn } # initialize the app with the extension @@ -502,38 +506,27 @@ your web application using [SQLAlchemy ORM](https://docs.sqlalchemy.org/en/14/or through the following: ```python -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine +import sqlalchemy from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from google.cloud.sql.connector import Connector -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> Engine: - # Python Connector database connection function - def getconn(): - conn = connector.connect( - "project:region:instance-name", # Cloud SQL Instance Connection Name - "pg8000", - user="my-user", - password="my-password", - db="my-database", - ip_type="public" # "private" for private IP - ) - return conn - - SQLALCHEMY_DATABASE_URL = "postgresql+pg8000://" - - engine = create_engine( - SQLALCHEMY_DATABASE_URL , creator=getconn - ) - return engine # initialize Cloud SQL Python Connector connector = Connector() # create connection pool engine -engine = init_connection_pool(connector) +engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=lambda: connector.connect( + "project:region:instance-name", # Cloud SQL Instance Connection Name + "pg8000", + user="my-user", + password="my-password", + db="my-database", + ip_type="public" # "private" for private IP + ), +) # create SQLAlchemy ORM session SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -602,40 +595,29 @@ async def main(): #### SQLAlchemy Async Engine ```python -import asyncpg - import sqlalchemy from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from google.cloud.sql.connector import Connector, create_async_connector -async def init_connection_pool(connector: Connector) -> AsyncEngine: - # creation function to generate asyncpg connections as 'async_creator' arg - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( + +async def main(): + # initialize Connector object for connections to Cloud SQL + connector = await create_async_connector() + + # The Cloud SQL Python Connector can be used along with SQLAlchemy using the + # 'async_creator' argument to 'create_async_engine' + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( "project:region:instance", # Cloud SQL instance connection name "asyncpg", user="my-user", password="my-password", db="my-db-name" # ... additional database driver args - ) - return conn - - # The Cloud SQL Python Connector can be used along with SQLAlchemy using the - # 'async_creator' argument to 'create_async_engine' - pool = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, + ), ) - return pool - -async def main(): - # initialize Connector object for connections to Cloud SQL - connector = await create_async_connector() - - # initialize connection pool - pool = await init_connection_pool(connector) # example query async with pool.connect() as conn: @@ -706,33 +688,24 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from google.cloud.sql.connector import Connector -async def init_connection_pool(connector: Connector) -> AsyncEngine: - # creation function to generate asyncpg connections as 'async_creator' arg - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( - "project:region:instance", # Cloud SQL instance connection name - "asyncpg", - user="my-user", - password="my-password", - db="my-db-name" - # ... additional database driver args - ) - return conn - - # The Cloud SQL Python Connector can be used along with SQLAlchemy using the - # 'async_creator' argument to 'create_async_engine' - pool = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, - ) - return pool async def main(): # initialize Connector object for connections to Cloud SQL loop = asyncio.get_running_loop() async with Connector(loop=loop) as connector: - # initialize connection pool - pool = await init_connection_pool(connector) + # The Cloud SQL Python Connector can be used along with SQLAlchemy using the + # 'async_creator' argument to 'create_async_engine' + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( + "project:region:instance", # Cloud SQL instance connection name + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ), + ) # example query async with pool.connect() as conn: diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 99a5097a2..6913337d3 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 8a31eb9a0..556a01bde 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -156,10 +156,23 @@ async def _get_metadata( # resolve dnsName into IP address for PSC # Note that we have to check for PSC enablement also because CAS # instances also set the dnsName field. - # Remove trailing period from DNS name. Required for SSL in Python - dns_name = ret_dict.get("dnsName", "").rstrip(".") - if dns_name and ret_dict.get("pscEnabled"): - ip_addresses["PSC"] = dns_name + if ret_dict.get("pscEnabled"): + # Find PSC instance DNS name in the dns_names field + psc_dns_names = [ + d["name"] + for d in ret_dict.get("dnsNames", []) + if d["connectionType"] == "PRIVATE_SERVICE_CONNECT" + and d["dnsScope"] == "INSTANCE" + ] + dns_name = psc_dns_names[0] if psc_dns_names else None + + # Fall back do dns_name field if dns_names is not set + if dns_name is None: + dns_name = ret_dict.get("dnsName", None) + + # Remove trailing period from DNS name. Required for SSL in Python + if dns_name: + ip_addresses["PSC"] = dns_name.rstrip(".") return { "ip_addresses": ip_addresses, diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 82e3a9018..c9e48935f 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import abc from dataclasses import dataclass import logging import ssl @@ -34,6 +35,27 @@ logger = logging.getLogger(name=__name__) +class ConnectionInfoCache(abc.ABC): + """Abstract class for Connector connection info caches.""" + + @abc.abstractmethod + async def connect_info(self) -> ConnectionInfo: + pass + + @abc.abstractmethod + async def force_refresh(self) -> None: + pass + + @abc.abstractmethod + async def close(self) -> None: + pass + + @property + @abc.abstractmethod + def closed(self) -> bool: + pass + + @dataclass class ConnectionInfo: """Contains all necessary information to connect securely to the diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index 1bf711ab7..ad5dc40fb 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -19,6 +19,10 @@ # Additionally, we have to support legacy "domain-scoped" projects # (e.g. "google.com:PROJECT") CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)")) +# The domain name pattern in accordance with RFC 1035, RFC 1123 and RFC 2181. +DOMAIN_NAME_REGEX = re.compile( + r"^(?:[_a-z0-9](?:[_a-z0-9-]{0,61}[a-z0-9])?\.)+(?:[a-z](?:[a-z0-9-]{0,61}[a-z0-9])?)?$" +) @dataclass @@ -38,6 +42,16 @@ def __str__(self) -> str: return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" return f"{self.project}:{self.region}:{self.instance_name}" + def get_connection_string(self) -> str: + """Get the instance connection string for the Cloud SQL instance.""" + return f"{self.project}:{self.region}:{self.instance_name}" + + +def _is_valid_domain(domain_name: str) -> bool: + if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None: + return False + return True + def _parse_connection_name(connection_name: str) -> ConnectionName: return _parse_connection_name_with_domain_name(connection_name, "") diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 3e53e754a..c76092a40 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,9 +20,10 @@ from functools import partial import logging import os +import socket from threading import Thread from types import TracebackType -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import google.auth from google.auth.credentials import Credentials @@ -35,6 +36,7 @@ from google.cloud.sql.connector.enums import RefreshStrategy from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -46,6 +48,7 @@ logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] +SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" _SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}" @@ -67,6 +70,7 @@ def __init__( universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, + failover_period: int = 30, ) -> None: """Initializes a Connector instance. @@ -114,6 +118,11 @@ def __init__( name. To resolve a DNS record to an instance connection name, use DnsResolver. Default: DefaultResolver + + failover_period (int): The time interval in seconds between each + attempt to check if a failover has occured for a given instance. + Must be used with `resolver=DnsResolver` to have any effect. + Default: 30 """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -143,9 +152,7 @@ def __init__( ) # initialize dict to store caches, key is a tuple consisting of instance # connection name string and enable_iam_auth boolean flag - self._cache: dict[ - tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache] - ] = {} + self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None # initialize credentials @@ -167,6 +174,7 @@ def __init__( self._enable_iam_auth = enable_iam_auth self._user_agent = user_agent self._resolver = resolver() + self._failover_period = failover_period # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) @@ -285,15 +293,19 @@ async def connect_async( driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - if (instance_connection_string, enable_iam_auth) in self._cache: - cache = self._cache[(instance_connection_string, enable_iam_auth)] + + conn_name = await self._resolver.resolve(instance_connection_string) + # Cache entry must exist and not be closed + if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ + (str(conn_name), enable_iam_auth) + ].closed: + monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] else: - conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) - cache = LazyRefreshCache( + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( conn_name, self._client, self._keys, @@ -309,8 +321,14 @@ async def connect_async( self._keys, enable_iam_auth, ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(instance_connection_string, enable_iam_auth)] = cache + self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache connect_func = { "pymysql": pymysql.connect, @@ -321,7 +339,7 @@ async def connect_async( # only accept supported database drivers try: - connector = connect_func[driver] + connector: Callable = connect_func[driver] # type: ignore except KeyError: raise KeyError(f"Driver '{driver}' is not supported.") @@ -339,14 +357,14 @@ async def connect_async( # attempt to get connection info for Cloud SQL instance try: - conn_info = await cache.connect_info() + conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) ip_address = conn_info.get_preferred_ip(ip_type) except Exception: # with an error from Cloud SQL Admin API call or IP type, invalidate # the cache and re-raise the error - await self._remove_cached(instance_connection_string, enable_iam_auth) + await self._remove_cached(str(conn_name), enable_iam_auth) raise logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn @@ -367,18 +385,28 @@ async def connect_async( await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) - # synchronous drivers are blocking and run using executor + # Create socket with SSLContext for sync drivers + ctx = await conn_info.create_ssl_context(enable_iam_auth) + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) + # If this connection was opened using a domain name, then store it + # for later in case we need to forcibly close it on failover. + if conn_info.conn_name.domain_name: + monitored_cache.sockets.append(sock) + # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, - await conn_info.create_ssl_context(enable_iam_auth), + sock, **kwargs, ) return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error - await cache.force_refresh() + await monitored_cache.force_refresh() raise async def _remove_cached( @@ -456,6 +484,7 @@ async def create_async_connector( universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, + failover_period: int = 30, ) -> Connector: """Helper function to create Connector object for asyncio connections. @@ -507,6 +536,11 @@ async def create_async_connector( DnsResolver. Default: DefaultResolver + failover_period (int): The time interval in seconds between each + attempt to check if a failover has occured for a given instance. + Must be used with `resolver=DnsResolver` to have any effect. + Default: 30 + Returns: A Connector instance configured with running event loop. """ @@ -525,4 +559,5 @@ async def create_async_connector( universe_domain=universe_domain, refresh_strategy=refresh_strategy, resolver=resolver, + failover_period=failover_period, ) diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 92e3e5662..da39ea25d 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -77,3 +77,10 @@ class DnsResolutionError(Exception): Exception to be raised when an instance connection name can not be resolved from a DNS record. """ + + +class CacheClosedError(Exception): + """ + Exception to be raised when a ConnectionInfoCache can not be accessed after + it is closed. + """ diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 5df272fe2..fb8711309 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -24,6 +24,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter @@ -35,7 +36,7 @@ APPLICATION_NAME = "cloud-sql-python-connector" -class RefreshAheadCache: +class RefreshAheadCache(ConnectionInfoCache): """Cache that refreshes connection info in the background prior to expiration. Background tasks are used to schedule refresh attempts to get a new @@ -74,6 +75,15 @@ def __init__( self._refresh_in_progress = asyncio.locks.Event() self._current: asyncio.Task = self._schedule_refresh(0) self._next: asyncio.Task = self._current + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed async def force_refresh(self) -> None: """ @@ -212,3 +222,4 @@ async def close(self) -> None: # gracefully wait for tasks to cancel tasks = asyncio.gather(self._current, self._next, return_exceptions=True) await asyncio.wait_for(tasks, timeout=2.0) + self._closed = True diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 1bc4f90f8..c75d07e52 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,13 +21,14 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) -class LazyRefreshCache: +class LazyRefreshCache(ConnectionInfoCache): """Cache that refreshes connection info when a caller requests a connection. Only refreshes the cache when a new connection is requested and the current @@ -62,6 +63,15 @@ def __init__( self._lock = asyncio.Lock() self._cached: Optional[ConnectionInfo] = None self._needs_refresh = False + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed async def force_refresh(self) -> None: """ @@ -121,4 +131,5 @@ async def close(self) -> None: """Close is a no-op and provided purely for a consistent interface with other cache types. """ - pass + self._closed = True + return diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py new file mode 100644 index 000000000..0c3fc4d03 --- /dev/null +++ b/google/cloud/sql/connector/monitored_cache.py @@ -0,0 +1,146 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import ssl +from typing import Any, Callable, Optional, Union + +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache +from google.cloud.sql.connector.exceptions import CacheClosedError +from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +logger = logging.getLogger(name=__name__) + + +class MonitoredCache(ConnectionInfoCache): + def __init__( + self, + cache: Union[RefreshAheadCache, LazyRefreshCache], + failover_period: int, + resolver: Union[DefaultResolver, DnsResolver], + ) -> None: + self.resolver = resolver + self.cache = cache + self.domain_name_ticker: Optional[asyncio.Task] = None + self.sockets: list[ssl.SSLSocket] = [] + + # If domain name is configured for instance and failover period is set, + # poll for DNS record changes. + if self.cache.conn_name.domain_name and failover_period > 0: + self.domain_name_ticker = asyncio.create_task( + ticker(failover_period, self._check_domain_name) + ) + logger.debug( + f"['{self.cache.conn_name}']: Configured polling of domain " + f"name with failover period of {failover_period} seconds." + ) + + @property + def closed(self) -> bool: + return self.cache.closed + + def _purge_closed_sockets(self) -> None: + """Remove closed sockets from monitored cache. + + If a socket is closed by the database driver we should remove it from + list of sockets. + """ + open_sockets = [] + for socket in self.sockets: + # Check fileno for if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + open_sockets.append(socket) + self.sockets = open_sockets + + async def _check_domain_name(self) -> None: + # remove any closed connections from cache + self._purge_closed_sockets() + try: + # Resolve domain name and see if Cloud SQL instance connection name + # has changed. If it has, close all connections. + new_conn_name = await self.resolver.resolve( + self.cache.conn_name.domain_name + ) + if new_conn_name != self.cache.conn_name: + logger.debug( + f"['{self.cache.conn_name}']: Cloud SQL instance changed " + f"from {self.cache.conn_name.get_connection_string()} to " + f"{new_conn_name.get_connection_string()}, closing all " + "connections!" + ) + await self.close() + + except Exception as e: + # Domain name checks should not be fatal, log error and continue. + logger.debug( + f"['{self.cache.conn_name}']: Unable to check domain name, " + f"domain name {self.cache.conn_name.domain_name} did not " + f"resolve: {e}" + ) + + async def connect_info(self) -> ConnectionInfo: + if self.closed: + raise CacheClosedError( + "Can not get connection info, cache has already been closed." + ) + return await self.cache.connect_info() + + async def force_refresh(self) -> None: + # if cache is closed do not refresh + if self.closed: + return + return await self.cache.force_refresh() + + async def close(self) -> None: + # Cancel domain name ticker task. + if self.domain_name_ticker: + self.domain_name_ticker.cancel() + try: + await self.domain_name_ticker + except asyncio.CancelledError: + logger.debug( + f"['{self.cache.conn_name}']: Cancelled domain name polling task." + ) + finally: + self.domain_name_ticker = None + # If cache is already closed, no further work. + if self.closed: + return + + # Close underyling ConnectionInfoCache + await self.cache.close() + + # Close any still open sockets + for socket in self.sockets: + # Check fileno for if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + socket.close() + + +async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None: + """ + Ticker function to sleep for specified interval and then schedule call + to given function. + """ + while True: + # Sleep for interval and then schedule task + await asyncio.sleep(interval) + asyncio.create_task(function(*args, **kwargs)) diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 1f66dde2a..baaee6615 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -14,18 +14,15 @@ limitations under the License. """ -import socket import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pg8000 def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pg8000.dbapi.Connection": """Helper function to create a pg8000 DB-API connection object. @@ -33,8 +30,8 @@ def connect( :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. @@ -48,12 +45,6 @@ def connect( 'Unable to import module "pg8000." Please install and try again.' ) - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index a16584367..f83f7076c 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -14,18 +14,15 @@ limitations under the License. """ -import socket import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pymysql def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pymysql.connections.Connection": """Helper function to create a pymysql DB-API connection object. @@ -33,8 +30,8 @@ def connect( :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. :rtype: pymysql.Connection @@ -50,11 +47,6 @@ def connect( # allow automatic IAM database authentication to not require password kwargs["password"] = kwargs["password"] if "password" in kwargs else None - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) # pop timeout as timeout arg is called 'connect_timeout' for pymysql timeout = kwargs.pop("timeout") kwargs["connect_timeout"] = kwargs.get("connect_timeout", timeout) diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 243d90fd5..3128fdb6a 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -15,27 +15,24 @@ """ import platform -import socket import ssl from typing import Any, TYPE_CHECKING from google.cloud.sql.connector.exceptions import PlatformNotSupportedError -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pytds -def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Connection": +def connect(ip_address: str, sock: ssl.SSLSocket, **kwargs: Any) -> "pytds.Connection": """Helper function to create a pytds DB-API connection object. :type ip_address: str :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. @@ -51,11 +48,6 @@ def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Conne db = kwargs.pop("db", None) - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) if kwargs.pop("active_directory_auth", False): if platform.system() == "Windows": # Ignore username and password if using active directory auth diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 39efd0492..7d717ca05 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -17,6 +17,7 @@ from google.cloud.sql.connector.connection_name import ( _parse_connection_name_with_domain_name, ) +from google.cloud.sql.connector.connection_name import _is_valid_domain from google.cloud.sql.connector.connection_name import _parse_connection_name from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError @@ -40,8 +41,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore conn_name = _parse_connection_name(dns) except ValueError: # The connection name was not project:region:instance format. - # Attempt to query a TXT record to get connection name. - conn_name = await self.query_dns(dns) + # Check if connection name is a valid DNS domain name + if _is_valid_domain(dns): + # Attempt to query a TXT record to get connection name. + conn_name = await self.query_dns(dns) + else: + raise ValueError( + "Arg `instance_connection_string` must have " + "format: PROJECT:REGION:INSTANCE or be a valid DNS domain " + f"name, got {dns}." + ) return conn_name async def query_dns(self, dns: str) -> ConnectionName: diff --git a/google/cloud/sql/connector/version.py b/google/cloud/sql/connector/version.py index 18c9772c7..f89ebde3c 100644 --- a/google/cloud/sql/connector/version.py +++ b/google/cloud/sql/connector/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.17.0" +__version__ = "1.18.0" diff --git a/pyproject.toml b/pyproject.toml index dec2ff489..8a694369b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,8 @@ build-backend = "setuptools.build_meta" description = "Google Cloud SQL Python Connector library" name = "cloud-sql-python-connector" authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }] -license = { text = "Apache 2.0" } +license = "Apache-2.0" +license-files = ["LICENSE"] requires-python = ">=3.9" readme = "README.md" classifiers = [ @@ -30,7 +31,6 @@ classifiers = [ # "Development Status :: 5 - Production/Stable" "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", diff --git a/requirements-test.txt b/requirements-test.txt index 52816e95c..7d276cbf2 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,8 +1,8 @@ -pytest==8.3.4 -mock==5.1.0 +pytest==8.3.5 +mock==5.2.0 pytest-cov==6.0.0 pytest-asyncio==0.25.3 -SQLAlchemy[asyncio]==2.0.38 +SQLAlchemy[asyncio]==2.0.39 sqlalchemy-pytds==1.0.2 sqlalchemy-stubs==0.4 PyMySQL==1.1.1 diff --git a/requirements.txt b/requirements.txt index fd04d2873..1dc6bc047 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ aiofiles==24.1.0 -aiohttp==3.11.12 -cryptography==44.0.1 +aiohttp==3.11.14 +cryptography==44.0.2 dnspython==2.7.0 Requests==2.32.3 google-auth==2.38.0 diff --git a/tests/conftest.py b/tests/conftest.py index 3a1a38a27..c75de48cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 8de14d576..2cc0716d6 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -16,13 +16,15 @@ import asyncio import os -from typing import Any +from typing import Any, Union import asyncpg import sqlalchemy import sqlalchemy.ext.asyncio from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver async def create_sqlalchemy_engine( @@ -31,6 +33,7 @@ async def create_sqlalchemy_engine( password: str, db: str, refresh_strategy: str = "background", + resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, ) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and the connector. Callers are responsible for closing the pool and the @@ -64,25 +67,30 @@ async def create_sqlalchemy_engine( Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver]): + Resolver class for resolving instance connection name. Use + google.cloud.sql.connector.DnsResolver when resolving DNS domain + names or google.cloud.sql.connector.DefaultResolver for regular + instance connection names ("my-project:my-region:my-instance"). """ loop = asyncio.get_running_loop() - connector = Connector(loop=loop, refresh_strategy=refresh_strategy) + connector = Connector( + loop=loop, refresh_strategy=refresh_strategy, resolver=resolver + ) - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( + # create SQLAlchemy connection pool + engine = sqlalchemy.ext.asyncio.create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( instance_connection_name, "asyncpg", user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc" - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.ext.asyncio.create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" + ), execution_options={"isolation_level": "AUTOCOMMIT"}, ) return engine, connector @@ -139,7 +147,9 @@ async def getconn( user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc", + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can also be "private" or "psc", **kwargs, ) return conn @@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None: await connector.close_async() +async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg() -> None: + """Basic test to get time from database.""" + inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] + db = os.environ["POSTGRES_DB"] + + pool, connector = await create_sqlalchemy_engine( + inst_conn_name, user, password, db, resolver=DnsResolver + ) + + async with pool.connect() as conn: + res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() + assert res[0] == 1 + + await connector.close_async() + + async def test_connection_with_asyncpg() -> None: """Basic test to get time from database.""" inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py index 6e96d96bd..459b3a686 100644 --- a/tests/system/test_asyncpg_iam_auth.py +++ b/tests/system/test_asyncpg_iam_auth.py @@ -17,9 +17,9 @@ import asyncio import os -import asyncpg import sqlalchemy import sqlalchemy.ext.asyncio +import logging from google.cloud.sql.connector import Connector @@ -64,21 +64,19 @@ async def create_sqlalchemy_engine( loop = asyncio.get_running_loop() connector = Connector(loop=loop, refresh_strategy=refresh_strategy) - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( + # create SQLAlchemy connection pool + engine = sqlalchemy.ext.asyncio.create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( instance_connection_name, "asyncpg", user=user, db=db, - ip_type="public", # can also be "private" or "psc" + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" enable_iam_auth=True, - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.ext.asyncio.create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, + ), execution_options={"isolation_level": "AUTOCOMMIT"}, ) return engine, connector @@ -112,3 +110,7 @@ async def test_lazy_iam_authn_connection_with_asyncpg() -> None: assert res[0] == 1 await connector.close_async() + +logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s") +logger = logging.getLogger(name="google.cloud.sql.connector") +logger.setLevel(logging.DEBUG) diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index c2b5cf125..ffd8b5f1f 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,7 +20,6 @@ import os from threading import Thread -import google.auth import pymysql import pytest import sqlalchemy @@ -50,20 +49,6 @@ def getconn() -> pymysql.connections.Connection: return pool -def test_connector_with_credentials() -> None: - """Test Connector object connection with credentials loaded from file.""" - credentials, _ = google.auth.load_credentials_from_file( - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] - ) - with Connector(credentials=credentials) as connector: - pool = init_connection_engine(connector) - - with pool.connect() as conn: - result = conn.execute(sqlalchemy.text("SELECT 1")).fetchone() - assert isinstance(result[0], int) - assert result[0] == 1 - - def test_multiple_connectors() -> None: """Test that same Cloud SQL instance can connect with two Connector objects.""" first_connector = Connector() diff --git a/tests/system/test_ip_types.py b/tests/system/test_ip_types.py index 2df3b1df5..3af49c54f 100644 --- a/tests/system/test_ip_types.py +++ b/tests/system/test_ip_types.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index b56a8e823..71582e9ed 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -18,10 +18,13 @@ import os # [START cloud_sql_connector_postgres_pg8000] -import pg8000 +from typing import Union + import sqlalchemy from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver def create_sqlalchemy_engine( @@ -30,6 +33,7 @@ def create_sqlalchemy_engine( password: str, db: str, refresh_strategy: str = "background", + resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, ) -> tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and the connector. Callers are responsible for closing the pool and the @@ -64,24 +68,27 @@ def create_sqlalchemy_engine( Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver]): + Resolver class for resolving instance connection name. Use + google.cloud.sql.connector.DnsResolver when resolving DNS domain + names or google.cloud.sql.connector.DefaultResolver for regular + instance connection names ("my-project:my-region:my-instance"). """ - connector = Connector(refresh_strategy=refresh_strategy) + connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) - def getconn() -> pg8000.dbapi.Connection: - conn: pg8000.dbapi.Connection = connector.connect( + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=lambda: connector.connect( instance_connection_name, "pg8000", user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc" - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" + ), ) return engine, connector @@ -153,3 +160,21 @@ def test_customer_managed_CAS_pg8000_connection() -> None: curr_time = time[0] assert type(curr_time) is datetime connector.close() + + +def test_custom_SAN_with_dns_pg8000_connection() -> None: + """Basic test to get time from database.""" + inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] + db = os.environ["POSTGRES_DB"] + + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, resolver=DnsResolver + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() diff --git a/tests/system/test_pg8000_iam_auth.py b/tests/system/test_pg8000_iam_auth.py index 9a8607bcb..c5456823f 100644 --- a/tests/system/test_pg8000_iam_auth.py +++ b/tests/system/test_pg8000_iam_auth.py @@ -17,7 +17,6 @@ from datetime import datetime import os -import pg8000 import sqlalchemy from google.cloud.sql.connector import Connector @@ -63,21 +62,19 @@ def create_sqlalchemy_engine( """ connector = Connector(refresh_strategy=refresh_strategy) - def getconn() -> pg8000.dbapi.Connection: - conn: pg8000.dbapi.Connection = connector.connect( + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=lambda: connector.connect( instance_connection_name, "pg8000", user=user, db=db, - ip_type="public", # can also be "private" or "psc" + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" enable_iam_auth=True, - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, + ), ) return engine, connector diff --git a/tests/system/test_pymysql_connection.py b/tests/system/test_pymysql_connection.py index 490b1fab4..3eda9dac2 100644 --- a/tests/system/test_pymysql_connection.py +++ b/tests/system/test_pymysql_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +18,6 @@ import os # [START cloud_sql_connector_mysql_pymysql] -import pymysql import sqlalchemy from google.cloud.sql.connector import Connector @@ -67,21 +66,19 @@ def create_sqlalchemy_engine( """ connector = Connector(refresh_strategy=refresh_strategy) - def getconn() -> pymysql.Connection: - conn: pymysql.Connection = connector.connect( + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( instance_connection_name, "pymysql", user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc" - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" + ), ) return engine, connector diff --git a/tests/system/test_pymysql_iam_auth.py b/tests/system/test_pymysql_iam_auth.py index 9a617b6f7..a4d5f2080 100644 --- a/tests/system/test_pymysql_iam_auth.py +++ b/tests/system/test_pymysql_iam_auth.py @@ -17,7 +17,6 @@ from datetime import datetime import os -import pymysql import sqlalchemy from google.cloud.sql.connector import Connector @@ -63,21 +62,19 @@ def create_sqlalchemy_engine( """ connector = Connector(refresh_strategy=refresh_strategy) - def getconn() -> pymysql.Connection: - conn: pymysql.Connection = connector.connect( + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( instance_connection_name, "pymysql", user=user, db=db, - ip_type="public", # can also be "private" or "psc" + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" enable_iam_auth=True, - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) return engine, connector diff --git a/tests/system/test_pytds_connection.py b/tests/system/test_pytds_connection.py index d848abc18..896c34965 100644 --- a/tests/system/test_pytds_connection.py +++ b/tests/system/test_pytds_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +17,6 @@ import os # [START cloud_sql_connector_mysql_pytds] -import pytds import sqlalchemy from google.cloud.sql.connector import Connector @@ -65,21 +64,19 @@ def create_sqlalchemy_engine( """ connector = Connector(refresh_strategy=refresh_strategy) - def getconn() -> pytds.Connection: - conn: pytds.Connection = connector.connect( + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "mssql+pytds://", + creator=lambda: connector.connect( instance_connection_name, "pytds", user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc" - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "mssql+pytds://", - creator=getconn, + ip_type=os.environ.get( + "IP_TYPE", "public" + ), # can be "public","private" or "psc" + ), ) return engine, connector diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 5d863677b..cd3299b7f 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -225,6 +225,7 @@ def __init__( "PRIMARY": "127.0.0.1", "PRIVATE": "10.0.0.1", }, + legacy_dns_name: bool = False, cert_before: datetime = datetime.datetime.now(datetime.timezone.utc), cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1), @@ -237,6 +238,7 @@ def __init__( self.psc_enabled = False self.cert_before = cert_before self.cert_expiration = cert_expiration + self.legacy_dns_name = legacy_dns_name # create self signed CA cert self.server_ca, self.server_key = generate_cert( self.project, self.name, cert_before, cert_expiration @@ -255,12 +257,22 @@ async def connect_settings(self, request: Any) -> web.Response: "instance": self.name, "expirationTime": str(self.cert_expiration), }, - "dnsName": "abcde.12345.us-central1.sql.goog", "pscEnabled": self.psc_enabled, "ipAddresses": ip_addrs, "region": self.region, "databaseVersion": self.db_version, } + if self.legacy_dns_name: + response["dnsName"] = "abcde.12345.us-central1.sql.goog" + else: + response["dnsNames"] = [ + { + "name": "abcde.12345.us-central1.sql.goog", + "connectionType": "PRIVATE_SERVICE_CONNECT", + "dnsScope": "INSTANCE", + } + ] + return web.Response(content_type="application/json", body=json.dumps(response)) async def generate_ephemeral(self, request: Any) -> web.Response: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index af42af0ae..cfe509470 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -65,6 +65,28 @@ async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None: assert isinstance(resp["server_ca_cert"], str) +@pytest.mark.asyncio +async def test_get_metadata_legacy_dns_with_psc(fake_client: CloudSQLClient) -> None: + """ + Test _get_metadata returns successfully with PSC IP type. + """ + # set PSC to enabled on test instance + fake_client.instance.psc_enabled = True + fake_client.instance.legacy_dns_name = True + resp = await fake_client._get_metadata( + "test-project", + "test-region", + "test-instance", + ) + assert resp["database_version"] == "POSTGRES_15" + assert resp["ip_addresses"] == { + "PRIMARY": "127.0.0.1", + "PRIVATE": "10.0.0.1", + "PSC": "abcde.12345.us-central1.sql.goog", + } + assert isinstance(resp["server_ca_cert"], str) + + @pytest.mark.asyncio async def test_get_ephemeral(fake_client: CloudSQLClient) -> None: """ diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index 783e14fe3..0861d8245 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -17,6 +17,7 @@ from google.cloud.sql.connector.connection_name import ( _parse_connection_name_with_domain_name, ) +from google.cloud.sql.connector.connection_name import _is_valid_domain from google.cloud.sql.connector.connection_name import _parse_connection_name from google.cloud.sql.connector.connection_name import ConnectionName @@ -30,6 +31,8 @@ def test_ConnectionName() -> None: assert conn_name.domain_name == "" # test ConnectionName str() method prints instance connection name assert str(conn_name) == "project:region:instance" + # test ConnectionName.get_connection_string + assert conn_name.get_connection_string() == "project:region:instance" def test_ConnectionName_with_domain_name() -> None: @@ -41,6 +44,8 @@ def test_ConnectionName_with_domain_name() -> None: assert conn_name.domain_name == "db.example.com" # test ConnectionName str() method prints with domain name assert str(conn_name) == "db.example.com -> project:region:instance" + # test ConnectionName.get_connection_string + assert conn_name.get_connection_string() == "project:region:instance" @pytest.mark.parametrize( @@ -96,3 +101,40 @@ def test_parse_connection_name_with_domain_name( assert expected == _parse_connection_name_with_domain_name( connection_name, domain_name ) + + +@pytest.mark.parametrize( + "domain_name, expected", + [ + ( + "prod-db.mycompany.example.com", + True, + ), + ( + "example.com.", # trailing dot + True, + ), + ( + "-example.com.", # leading hyphen + False, + ), + ( + "example", # missing TLD + False, + ), + ( + "127.0.0.1", # IPv4 address + False, + ), + ( + "0:0:0:0:0:0:0:1", # IPv6 address + False, + ), + ], +) +def test_is_valid_domain(domain_name: str, expected: bool) -> None: + """ + Test that _is_valid_domain works correctly for + parsing domain names. + """ + assert expected == _is_valid_domain(domain_name) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index e25c9a384..498c947cc 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index aeedf3399..1a3d60917 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 344b073e8..c6eef7509 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -21,6 +21,27 @@ from google.cloud.sql.connector.utils import generate_keys +async def test_LazyRefreshCache_properties(fake_client: CloudSQLClient) -> None: + """ + Test that LazyRefreshCache properties work as expected. + """ + keys = asyncio.create_task(generate_keys()) + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=keys, + enable_iam_auth=False, + ) + # test conn_name property + assert cache.conn_name == conn_name + # test closed property + assert cache.closed is False + # close cache and make sure property is updated + await cache.close() + assert cache.closed is True + + async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> None: """ Test that LazyRefreshCache.connect_info works as expected. diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py new file mode 100644 index 000000000..1eea4eb46 --- /dev/null +++ b/tests/unit/test_monitored_cache.py @@ -0,0 +1,240 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import socket + +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch +from mocks import create_ssl_context +import pytest + +from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import CacheClosedError +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.monitored_cache import MonitoredCache +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.utils import generate_keys + +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_MonitoredCache_properties(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache properties work as expected. + """ + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 30, DefaultResolver()) + # test that ticker is not set for instance not using domain name + assert monitored_cache.domain_name_ticker is None + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + + +async def test_MonitoredCache_CacheClosedError(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache.connect_info errors when cache is closed. + """ + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 30, DefaultResolver()) + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + # attempt to get connect info from closed cache + with pytest.raises(CacheClosedError): + await monitored_cache.connect_info() + + +async def test_MonitoredCache_with_DnsResolver(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache with DnsResolver work as expected. + """ + conn_name = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + monitored_cache = MonitoredCache(cache, 30, resolver) + # test that ticker is set for instance using domain name + assert type(monitored_cache.domain_name_ticker) is asyncio.Task + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + # domain name ticker should be set back to None + assert monitored_cache.domain_name_ticker is None + + +async def test_MonitoredCache_with_disabled_failover( + fake_client: CloudSQLClient, +) -> None: + """ + Test that MonitoredCache disables DNS polling with failover_period=0 + """ + conn_name = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 0, DnsResolver()) + # test that ticker is not set when failover is disabled + assert monitored_cache.domain_name_ticker is None + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + + +@pytest.mark.usefixtures("server") +async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache is closed when _check_domain_name has domain change. + """ + conn_name = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + + # configure a local socket + ip_addr = "127.0.0.1" + context = await create_ssl_context() + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, + ) + # verify socket is open + assert sock.fileno() != -1 + # set failover to 0 to disable polling + monitored_cache = MonitoredCache(cache, 0, resolver) + # add socket to cache + monitored_cache.sockets = [sock] + # check cache is not closed + assert monitored_cache.closed is False + # call _check_domain_name and verify cache is closed + await monitored_cache._check_domain_name() + assert monitored_cache.closed is True + # verify socket was closed + assert sock.fileno() == -1 + + +@pytest.mark.usefixtures("server") +async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache._purge_closed_sockets removes closed sockets from + cache. + """ + conn_name = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # configure a local socket + ip_addr = "127.0.0.1" + context = await create_ssl_context() + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, + ) + + # set failover to 0 to disable polling + monitored_cache = MonitoredCache(cache, 0, DnsResolver()) + # verify socket is open + assert sock.fileno() != -1 + # add socket to cache + monitored_cache.sockets = [sock] + # call _purge_closed_sockets and verify socket remains + monitored_cache._purge_closed_sockets() + # verify socket is still open + assert sock.fileno() != -1 + assert len(monitored_cache.sockets) == 1 + # close socket + sock.close() + # call _purge_closed_sockets and verify socket is removed + monitored_cache._purge_closed_sockets() + assert len(monitored_cache.sockets) == 0 diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py index 1b2adbb65..e01a53445 100644 --- a/tests/unit/test_pg8000.py +++ b/tests/unit/test_pg8000.py @@ -14,7 +14,7 @@ limitations under the License. """ -from functools import partial +import socket from typing import Any from mock import patch @@ -31,15 +31,14 @@ async def test_pg8000(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) with patch("pg8000.dbapi.connect") as mock_connect: mock_connect.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) assert connection is True # verify that driver connection call would be made assert mock_connect.assert_called_once diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py index 69d2aba8f..66b1f22a3 100644 --- a/tests/unit/test_pymysql.py +++ b/tests/unit/test_pymysql.py @@ -14,7 +14,7 @@ limitations under the License. """ -from functools import partial +import socket import ssl from typing import Any @@ -40,15 +40,14 @@ async def test_pymysql(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) kwargs["timeout"] = 30 with patch("pymysql.Connection") as mock_connect: mock_connect.return_value = MockConnection - pymysql_connect(ip_addr, context, **kwargs) + pymysql_connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert mock_connect.assert_called_once diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py index 633aab74a..9efe00ee5 100644 --- a/tests/unit/test_pytds.py +++ b/tests/unit/test_pytds.py @@ -14,8 +14,8 @@ limitations under the License. """ -from functools import partial import platform +import socket from typing import Any from mock import patch @@ -43,16 +43,15 @@ async def test_pytds(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) with patch("pytds.connect") as mock_connect: mock_connect.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert connection is True assert mock_connect.assert_called_once @@ -68,17 +67,16 @@ async def test_pytds_platform_error(kwargs: Any) -> None: assert platform.system() == "Linux" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) # add active_directory_auth to kwargs kwargs["active_directory_auth"] = True # verify that error is thrown with Linux and active_directory_auth with pytest.raises(PlatformNotSupportedError): - connect(ip_addr, context, **kwargs) + connect(ip_addr, sock, **kwargs) @pytest.mark.usefixtures("server") @@ -94,11 +92,10 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: assert platform.system() == "Windows" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) # add active_directory_auth and server_name to kwargs kwargs["active_directory_auth"] = True @@ -107,7 +104,7 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: mock_connect.return_value = True with patch("pytds.login.SspiAuth") as mock_login: mock_login.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert mock_login.assert_called_once assert connection is True diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py index 5e187b81d..8ef586b58 100644 --- a/tests/unit/test_rate_limiter.py +++ b/tests/unit/test_rate_limiter.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 6545bc7a8..fe4e90955 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License");