Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions google/cloud/sql/connector/connection_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
# Additionally, we have to support legacy "domain-scoped" projects
# (e.g. "google.com:PROJECT")
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))
# The domain name pattern in accordance with RFC 1035, RFC 1123 and RFC 2181.
DOMAIN_NAME_REGEX = re.compile(
r"^(?:[_a-z0-9](?:[_a-z0-9-]{0,61}[a-z0-9])?\.)+(?:[a-z](?:[a-z0-9-]{0,61}[a-z0-9])?)?$"
)


@dataclass
Expand All @@ -39,6 +43,12 @@ def __str__(self) -> str:
return f"{self.project}:{self.region}:{self.instance_name}"


def _is_valid_domain(domain_name: str) -> bool:
if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None:
return False
return True


def _parse_connection_name(connection_name: str) -> ConnectionName:
return _parse_connection_name_with_domain_name(connection_name, "")

Expand Down
13 changes: 11 additions & 2 deletions google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.cloud.sql.connector.connection_name import (
_parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import _is_valid_domain
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError
Expand All @@ -40,8 +41,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore
conn_name = _parse_connection_name(dns)
except ValueError:
# The connection name was not project:region:instance format.
# Attempt to query a TXT record to get connection name.
conn_name = await self.query_dns(dns)
# Check if connection name is a valid DNS domain name
if _is_valid_domain(dns):
# Attempt to query a TXT record to get connection name.
conn_name = await self.query_dns(dns)
else:
raise ValueError(
"Arg `instance_connection_string` must have "
"format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
f"name, got {dns}."
)
return conn_name

async def query_dns(self, dns: str) -> ConnectionName:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_connection_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.cloud.sql.connector.connection_name import (
_parse_connection_name_with_domain_name,
)
from google.cloud.sql.connector.connection_name import _is_valid_domain
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName

Expand Down Expand Up @@ -96,3 +97,40 @@ def test_parse_connection_name_with_domain_name(
assert expected == _parse_connection_name_with_domain_name(
connection_name, domain_name
)


@pytest.mark.parametrize(
"domain_name, expected",
[
(
"prod-db.mycompany.example.com",
True,
),
(
"example.com.", # trailing dot
True,
),
(
"-example.com.", # leading hyphen
False,
),
(
"example", # missing TLD
False,
),
(
"127.0.0.1", # IPv4 address
False,
),
(
"0:0:0:0:0:0:0:1", # IPv6 address
False,
),
],
)
def test_is_valid_domain(domain_name: str, expected: bool) -> None:
"""
Test that _is_valid_domain works correctly for
parsing domain names.
"""
assert expected == _is_valid_domain(domain_name)
Loading