diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py
index 1bf711ab7..437fd6607 100644
--- a/google/cloud/sql/connector/connection_name.py
+++ b/google/cloud/sql/connector/connection_name.py
@@ -19,6 +19,13 @@
 # Additionally, we have to support legacy "domain-scoped" projects
 # (e.g. "google.com:PROJECT")
 CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))
+# The domain name pattern in accordance with RFC 1035, RFC 1123 and RFC 2181.
+# DNS names compare case-insensitively (RFC 4343), so letters match in either
+# case; re.ASCII keeps IGNORECASE from also matching non-ASCII letters.
+DOMAIN_NAME_REGEX = re.compile(
+    r"^(?:[_a-z0-9](?:[_a-z0-9-]{0,61}[a-z0-9])?\.)+(?:[a-z](?:[a-z0-9-]{0,61}[a-z0-9])?)?$",
+    re.IGNORECASE | re.ASCII,
+)
 
 
 @dataclass
@@ -39,6 +46,11 @@ def __str__(self) -> str:
         return f"{self.project}:{self.region}:{self.instance_name}"
 
 
+def _is_valid_domain(domain_name: str) -> bool:
+    """Return True if domain_name is a syntactically valid DNS domain name."""
+    return DOMAIN_NAME_REGEX.fullmatch(domain_name) is not None
+
+
 def _parse_connection_name(connection_name: str) -> ConnectionName:
     return _parse_connection_name_with_domain_name(connection_name, "")
 
diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py
index 39efd0492..7d717ca05 100644
--- a/google/cloud/sql/connector/resolver.py
+++ b/google/cloud/sql/connector/resolver.py
@@ -17,6 +17,7 @@
 from google.cloud.sql.connector.connection_name import (
     _parse_connection_name_with_domain_name,
 )
+from google.cloud.sql.connector.connection_name import _is_valid_domain
 from google.cloud.sql.connector.connection_name import _parse_connection_name
 from google.cloud.sql.connector.connection_name import ConnectionName
 from google.cloud.sql.connector.exceptions import DnsResolutionError
@@ -40,8 +41,16 @@ async def resolve(self, dns: str) -> ConnectionName:  # type: ignore
             conn_name = _parse_connection_name(dns)
         except ValueError:
             # The connection name was not project:region:instance format.
-            # Attempt to query a TXT record to get connection name.
-            conn_name = await self.query_dns(dns)
+            # Check if connection name is a valid DNS domain name
+            if _is_valid_domain(dns):
+                # Attempt to query a TXT record to get connection name.
+                conn_name = await self.query_dns(dns)
+            else:
+                raise ValueError(
+                    "Arg `instance_connection_string` must have "
+                    "format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
+                    f"name, got {dns}."
+                )
         return conn_name
 
     async def query_dns(self, dns: str) -> ConnectionName:
diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py
index 783e14fe3..218034d51 100644
--- a/tests/unit/test_connection_name.py
+++ b/tests/unit/test_connection_name.py
@@ -17,6 +17,7 @@
 from google.cloud.sql.connector.connection_name import (
     _parse_connection_name_with_domain_name,
 )
+from google.cloud.sql.connector.connection_name import _is_valid_domain
 from google.cloud.sql.connector.connection_name import _parse_connection_name
 from google.cloud.sql.connector.connection_name import ConnectionName
 
@@ -96,3 +97,44 @@ def test_parse_connection_name_with_domain_name(
     assert expected == _parse_connection_name_with_domain_name(
         connection_name, domain_name
     )
+
+
+@pytest.mark.parametrize(
+    "domain_name, expected",
+    [
+        (
+            "prod-db.mycompany.example.com",
+            True,
+        ),
+        (
+            "Prod-DB.MyCompany.Example.COM",  # mixed case
+            True,
+        ),
+        (
+            "example.com.",  # trailing dot
+            True,
+        ),
+        (
+            "-example.com.",  # leading hyphen
+            False,
+        ),
+        (
+            "example",  # missing TLD
+            False,
+        ),
+        (
+            "127.0.0.1",  # IPv4 address
+            False,
+        ),
+        (
+            "0:0:0:0:0:0:0:1",  # IPv6 address
+            False,
+        ),
+    ],
+)
+def test_is_valid_domain(domain_name: str, expected: bool) -> None:
+    """
+    Test that _is_valid_domain works correctly for
+    validating domain names.
+    """
+    assert expected == _is_valid_domain(domain_name)