diff --git a/README.md b/README.md index 28553f972..1f0e633b9 100644 --- a/README.md +++ b/README.md @@ -365,6 +365,69 @@ conn = connector.connect( ) ``` +### Using DNS domain names to identify instances + +The connector can be configured to use DNS to look up an instance. This would +allow you to configure your application to connect to a database instance, and +centrally configure which instance in your DNS zone. + +#### Configure your DNS Records + +Add a DNS TXT record for the Cloud SQL instance to a **private** DNS server +or a private Google Cloud DNS Zone used by your application. + +> [!NOTE] +> +> You are strongly discouraged from adding DNS records for your +> Cloud SQL instances to a public DNS server. This would allow anyone on the +> internet to discover the Cloud SQL instance name. + +For example: suppose you wanted to use the domain name +`prod-db.mycompany.example.com` to connect to your database instance +`my-project:region:my-instance`. You would create the following DNS record: + +* Record type: `TXT` +* Name: `prod-db.mycompany.example.com` – This is the domain name used by the application +* Value: `my-project:my-region:my-instance` – This is the Cloud SQL instance connection name + +#### Configure the connector + +Configure the connector to resolve DNS names by initializing it with +`resolver=DnsResolver` and replacing the instance connection name with the DNS +name in `connector.connect`: + +```python +from google.cloud.sql.connector import Connector, DnsResolver +import pymysql +import sqlalchemy + +# helper function to return SQLAlchemy connection pool +def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: + # function used to generate database connection + def getconn() -> pymysql.connections.Connection: + conn = connector.connect( + "prod-db.mycompany.example.com", # using DNS name + "pymysql", + user="my-user", + password="my-password", + db="my-db-name" + ) + return conn + + # create connection pool + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=getconn, + ) + return pool + +# initialize Cloud SQL Python Connector with `resolver=DnsResolver` +with Connector(resolver=DnsResolver) as connector: + # initialize connection pool + pool = init_connection_pool(connector) + # ... use SQLAlchemy engine normally +``` + ### Using the Python Connector with Python Web Frameworks The Python Connector can be used alongside popular Python web frameworks such diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 5b06fcd7f..99a5097a2 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -18,12 +18,16 @@ from google.cloud.sql.connector.connector import create_async_connector from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.version import __version__ __all__ = [ "__version__", "create_async_connector", "Connector", + "DefaultResolver", + "DnsResolver", "IPTypes", "RefreshStrategy", ] diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7a89d7194..1e67373eb 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,6 +37,8 @@ import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys @@ -63,6 +65,7 @@ def __init__( user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + resolver: Type[DefaultResolver] | Type[DnsResolver] = DefaultResolver, ) -> None: """Initializes a Connector instance. @@ -104,6 +107,13 @@ def __init__( of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + + resolver (DefaultResolver | DnsResolver): The class name of the + resolver to use for resolving the Cloud SQL instance connection + name. To resolve a DNS record to an instance connection name, use + DnsResolver. + Default: DefaultResolver + """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -157,6 +167,7 @@ def __init__( self._enable_iam_auth = enable_iam_auth self._quota_project = quota_project self._user_agent = user_agent + self._resolver = resolver() # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) @@ -269,13 +280,14 @@ async def connect_async( if (instance_connection_string, enable_iam_auth) in self._cache: cache = self._cache[(instance_connection_string, enable_iam_auth)] else: + conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{instance_connection_string}']: Refresh strategy is set" " to lazy refresh" ) cache = LazyRefreshCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, @@ -286,7 +298,7 @@ async def connect_async( " to backgound refresh" ) cache = RefreshAheadCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 7bff2300d..92e3e5662 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -70,3 +70,10 @@ class IncompatibleDriverError(Exception): Exception to be raised when the database driver given is for the wrong database engine. (i.e. asyncpg for a MySQL database) """ + + +class DnsResolutionError(Exception): + """ + Exception to be raised when an instance connection name can not be resolved + from a DNS record. + """ diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 9cf9bc787..3b0b9263d 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -24,7 +24,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid @@ -45,7 +45,7 @@ class RefreshAheadCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -53,8 +53,8 @@ def __init__( """Initializes a RefreshAheadCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -62,8 +62,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) self._project, self._region, self._instance = ( conn_name.project, conn_name.region, diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 672f989e8..ab73785d1 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,7 +21,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) @@ -38,7 +38,7 @@ class LazyRefreshCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -46,8 +46,8 @@ def __init__( """Initializes a LazyRefreshCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -55,8 +55,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) self._project, self._region, self._instance = ( conn_name.project, conn_name.region, diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py new file mode 100644 index 000000000..15ccd6a21 --- /dev/null +++ b/google/cloud/sql/connector/resolver.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dns.asyncresolver + +from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError + + +class DefaultResolver: + """DefaultResolver simply validates and parses instance connection name.""" + + async def resolve(self, connection_name: str) -> ConnectionName: + return _parse_instance_connection_name(connection_name) + + +class DnsResolver(dns.asyncresolver.Resolver): + """ + DnsResolver resolves domain names into instance connection names using + TXT records in DNS. + """ + + async def resolve(self, dns: str) -> ConnectionName: # type: ignore + try: + conn_name = _parse_instance_connection_name(dns) + except ValueError: + # The connection name was not project:region:instance format. + # Attempt to query a TXT record to get connection name. + conn_name = await self.query_dns(dns) + return conn_name + + async def query_dns(self, dns: str) -> ConnectionName: + try: + # Attempt to query the TXT records. + records = await super().resolve(dns, "TXT", raise_on_no_answer=True) + # Sort the TXT record values alphabetically, strip quotes as record + # values can be returned as raw strings + rdata = [record.to_text().strip('"') for record in records] + rdata.sort() + # Attempt to parse records, returning the first valid record. + for record in rdata: + try: + conn_name = _parse_instance_connection_name(record) + return conn_name + except Exception: + continue + # If all records failed to parse, throw error + raise DnsResolutionError( + f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`" + ) + # Don't override above DnsResolutionError + except DnsResolutionError: + raise + except Exception as e: + raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e diff --git a/requirements.txt b/requirements.txt index 466849d56..0ff0e41e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles==24.1.0 aiohttp==3.11.9 cryptography==44.0.0 +dnspython==2.7.0 Requests==2.32.3 google-auth==2.36.0 diff --git a/setup.py b/setup.py index bb70449a5..79c6acf74 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "aiofiles", "aiohttp", "cryptography>=42.0.0", + "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", ] diff --git a/tests/conftest.py b/tests/conftest.py index 470fe19f4..3a1a38a27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ from unit.mocks import FakeCSQLInstance # type: ignore from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.utils import generate_keys @@ -144,7 +145,7 @@ async def fake_client( async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache, None]: keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index fd18f2d5e..d4f53ed51 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -26,6 +26,7 @@ from google.cloud.sql.connector import create_async_connector from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -322,18 +323,18 @@ async def test_Connector_remove_cached_bad_instance( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "bad-project:bad-region:bad-inst" + conn_name = ConnectionName("bad-project", "bad-region", "bad-inst") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # aiohttp client should throw a 404 ClientResponseError with pytest.raises(ClientResponseError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache async def test_Connector_remove_cached_no_ip_type( @@ -348,21 +349,21 @@ async def test_Connector_remove_cached_no_ip_type( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "test-project:test-region:test-instance" + conn_name = ConnectionName("test-project", "test-region", "test-instance") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # test instance does not have Private IP, thus should invalidate cache with pytest.raises(CloudSQLIPTypeError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", user="my-user", password="my-pass", ip_type="private", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache def test_default_universe_domain(fake_credentials: Credentials) -> None: diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 3ce0386b2..f80bb1494 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -24,6 +24,7 @@ from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -271,7 +272,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None # generate client key pair keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:sqlserver-instance", + ConnectionName("test-project", "test-region", "sqlserver-instance"), client=fake_client, keys=keys, enable_iam_auth=True, diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 27cd80b4f..344b073e8 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -16,6 +16,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.utils import generate_keys @@ -26,7 +27,7 @@ async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> Non """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, @@ -47,7 +48,7 @@ async def test_LazyRefreshCache_force_refresh(fake_client: CloudSQLClient) -> No """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py new file mode 100644 index 000000000..d7404890a --- /dev/null +++ b/tests/unit/test_resolver.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch +import pytest + +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +conn_str = "my-project:my-region:my-instance" +conn_name = ConnectionName("my-project", "my-region", "my-instance") + + +async def test_DefaultResolver() -> None: + """Test DefaultResolver just parses instance connection string.""" + resolver = DefaultResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +async def test_DnsResolver_with_conn_str() -> None: + """Test DnsResolver with instance connection name just parses connection string.""" + resolver = DnsResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +db.example.com. 0 IN TXT "my-project:my-region:my-instance" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_with_dns_name() -> None: + """Test DnsResolver resolves TXT record into proper instance connection name. + + Should sort valid TXT records alphabetically and take first one. + """ + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + # Resolution should return first value sorted alphabetically + result = await resolver.resolve("db.example.com") + assert result == conn_name + + +query_text_malformed = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +bad.example.com. IN TXT +;ANSWER +bad.example.com. 0 IN TXT "malformed-instance-name" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_with_malformed_txt() -> None: + """Test DnsResolver with TXT record that holds malformed instance connection name. + + Should throw DnsResolutionError + """ + # patch DNS resolution with malformed TXT record + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "bad.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text_malformed), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.example.com") + assert ( + exc_info.value.args[0] + == "Unable to parse TXT record for `bad.example.com` -> `malformed-instance-name`" + ) + + +async def test_DnsResolver_with_bad_dns_name() -> None: + """Test DnsResolver with bad dns name. + + Should throw DnsResolutionError + """ + resolver = DnsResolver() + resolver.port = 5053 + # set lifetime to 1 second for shorter timeout + resolver.lifetime = 1 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.dns.com") + assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`"