diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index d240fb565..1bf711ab7 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -31,12 +31,21 @@ class ConnectionName: project: str region: str instance_name: str + domain_name: str = "" def __str__(self) -> str: + if self.domain_name: + return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" return f"{self.project}:{self.region}:{self.instance_name}" -def _parse_instance_connection_name(connection_name: str) -> ConnectionName: +def _parse_connection_name(connection_name: str) -> ConnectionName: + return _parse_connection_name_with_domain_name(connection_name, "") + + +def _parse_connection_name_with_domain_name( + connection_name: str, domain_name: str +) -> ConnectionName: if CONN_NAME_REGEX.fullmatch(connection_name) is None: raise ValueError( "Arg `instance_connection_string` must have " @@ -48,4 +57,5 @@ def _parse_instance_connection_name(connection_name: str) -> ConnectionName: connection_name_split[1], connection_name_split[3], connection_name_split[4], + domain_name, ) diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 15ccd6a21..2cdcddbe2 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -14,7 +14,7 @@ import dns.asyncresolver -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import _parse_connection_name from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError @@ -23,7 +23,7 @@ class DefaultResolver: """DefaultResolver simply validates and parses instance connection name.""" async def resolve(self, connection_name: str) -> ConnectionName: - return _parse_instance_connection_name(connection_name) + return _parse_connection_name(connection_name) class DnsResolver(dns.asyncresolver.Resolver): @@ -34,7 +34,7 @@ class DnsResolver(dns.asyncresolver.Resolver): async def resolve(self, dns: str) -> ConnectionName: # type: ignore try: - conn_name = _parse_instance_connection_name(dns) + conn_name = _parse_connection_name(dns) except ValueError: # The connection name was not project:region:instance format. # Attempt to query a TXT record to get connection name. @@ -52,7 +52,7 @@ async def query_dns(self, dns: str) -> ConnectionName: # Attempt to parse records, returning the first valid record. for record in rdata: try: - conn_name = _parse_instance_connection_name(record) + conn_name = _parse_connection_name(record) return conn_name except Exception: continue diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index 1e3730424..a62f88d5f 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -14,9 +14,14 @@ import pytest # noqa F401 Needed to run the tests -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +# fmt: off +from google.cloud.sql.connector.connection_name import _parse_connection_name +from google.cloud.sql.connector.connection_name import \ + _parse_connection_name_with_domain_name from google.cloud.sql.connector.connection_name import ConnectionName +# fmt: on + def test_ConnectionName() -> None: conn_name = ConnectionName("project", "region", "instance") @@ -24,10 +29,22 @@ def test_ConnectionName() -> None: assert conn_name.project == "project" assert conn_name.region == "region" assert conn_name.instance_name == "instance" + assert conn_name.domain_name == "" # test ConnectionName str() method prints instance connection name assert str(conn_name) == "project:region:instance" +def test_ConnectionName_with_domain_name() -> None: + conn_name = ConnectionName("project", "region", "instance", "db.example.com") + # test class attributes are set properly + assert conn_name.project == "project" + assert conn_name.region == "region" + assert conn_name.instance_name == "instance" + assert conn_name.domain_name == "db.example.com" + # test ConnectionName str() method prints with domain name + assert str(conn_name) == "db.example.com -> project:region:instance" + + @pytest.mark.parametrize( "connection_name, expected", [ @@ -38,19 +55,46 @@ def test_ConnectionName() -> None: ), ], ) -def test_parse_instance_connection_name( - connection_name: str, expected: ConnectionName -) -> None: +def test_parse_connection_name(connection_name: str, expected: ConnectionName) -> None: """ - Test that _parse_instance_connection_name works correctly on + Test that _parse_connection_name works correctly on normal instance connection names and domain-scoped projects. """ - assert expected == _parse_instance_connection_name(connection_name) + assert expected == _parse_connection_name(connection_name) -def test_parse_instance_connection_name_bad_conn_name() -> None: +def test_parse_connection_name_bad_conn_name() -> None: """ Tests that ValueError is thrown for bad instance connection names. """ with pytest.raises(ValueError): - _parse_instance_connection_name("project:instance") # missing region + _parse_connection_name("project:instance") # missing region + + +@pytest.mark.parametrize( + "connection_name, domain_name, expected", + [ + ( + "project:region:instance", + "db.example.com", + ConnectionName("project", "region", "instance", "db.example.com"), + ), + ( + "domain-prefix:project:region:instance", + "db.example.com", + ConnectionName( + "domain-prefix:project", "region", "instance", "db.example.com" + ), + ), + ], +) +def test_parse_connection_name_with_domain_name( + connection_name: str, domain_name: str, expected: ConnectionName +) -> None: + """ + Test that _parse_connection_name_with_domain_name works correctly on + normal instance connection names and domain-scoped projects. + """ + assert expected == _parse_connection_name_with_domain_name( + connection_name, domain_name + )