Skip to content

Commit da8fea2

Browse files
chore: add domain_name to ConnectionName class
1 parent 1a8f274 commit da8fea2

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

google/cloud/sql/connector/connection_name.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,21 @@ class ConnectionName:
3131
project: str
3232
region: str
3333
instance_name: str
34+
domain_name: str = ""
3435

3536
def __str__(self) -> str:
37+
if self.domain_name:
38+
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
3639
return f"{self.project}:{self.region}:{self.instance_name}"
3740

3841

39-
def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
42+
def _parse_connection_name(connection_name: str) -> ConnectionName:
43+
return _parse_connection_name_with_domain_name(connection_name, "")
44+
45+
46+
def _parse_connection_name_with_domain_name(
47+
connection_name: str, domain_name: str
48+
) -> ConnectionName:
4049
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
4150
raise ValueError(
4251
"Arg `instance_connection_string` must have "
@@ -48,4 +57,5 @@ def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
4857
connection_name_split[1],
4958
connection_name_split[3],
5059
connection_name_split[4],
60+
domain_name,
5161
)

google/cloud/sql/connector/resolver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import dns.asyncresolver
1616

17-
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
17+
from google.cloud.sql.connector.connection_name import _parse_connection_name
1818
from google.cloud.sql.connector.connection_name import ConnectionName
1919
from google.cloud.sql.connector.exceptions import DnsResolutionError
2020

@@ -23,7 +23,7 @@ class DefaultResolver:
2323
"""DefaultResolver simply validates and parses instance connection name."""
2424

2525
async def resolve(self, connection_name: str) -> ConnectionName:
26-
return _parse_instance_connection_name(connection_name)
26+
return _parse_connection_name(connection_name)
2727

2828

2929
class DnsResolver(dns.asyncresolver.Resolver):
@@ -34,7 +34,7 @@ class DnsResolver(dns.asyncresolver.Resolver):
3434

3535
async def resolve(self, dns: str) -> ConnectionName: # type: ignore
3636
try:
37-
conn_name = _parse_instance_connection_name(dns)
37+
conn_name = _parse_connection_name(dns)
3838
except ValueError:
3939
# The connection name was not project:region:instance format.
4040
# Attempt to query a TXT record to get connection name.
@@ -52,7 +52,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
5252
# Attempt to parse records, returning the first valid record.
5353
for record in rdata:
5454
try:
55-
conn_name = _parse_instance_connection_name(record)
55+
conn_name = _parse_connection_name(record)
5656
return conn_name
5757
except Exception:
5858
continue

tests/unit/test_connection_name.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest # noqa F401 Needed to run the tests
1616

17-
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
17+
from google.cloud.sql.connector.connection_name import _parse_connection_name
1818
from google.cloud.sql.connector.connection_name import ConnectionName
1919

2020

@@ -24,10 +24,22 @@ def test_ConnectionName() -> None:
2424
assert conn_name.project == "project"
2525
assert conn_name.region == "region"
2626
assert conn_name.instance_name == "instance"
27+
assert conn_name.domain_name == ""
2728
# test ConnectionName str() method prints instance connection name
2829
assert str(conn_name) == "project:region:instance"
2930

3031

32+
def test_ConnectionName_with_domain_name() -> None:
33+
conn_name = ConnectionName("project", "region", "instance", "db.example.com")
34+
# test class attributes are set properly
35+
assert conn_name.project == "project"
36+
assert conn_name.region == "region"
37+
assert conn_name.instance_name == "instance"
38+
assert conn_name.domain_name == "db.example.com"
39+
# test ConnectionName str() method prints with domain name
40+
assert str(conn_name) == "db.example.com -> project:region:instance"
41+
42+
3143
@pytest.mark.parametrize(
3244
"connection_name, expected",
3345
[
@@ -38,19 +50,17 @@ def test_ConnectionName() -> None:
3850
),
3951
],
4052
)
41-
def test_parse_instance_connection_name(
42-
connection_name: str, expected: ConnectionName
43-
) -> None:
53+
def test_parse_connection_name(connection_name: str, expected: ConnectionName) -> None:
4454
"""
45-
Test that _parse_instance_connection_name works correctly on
55+
Test that _parse_connection_name works correctly on
4656
normal instance connection names and domain-scoped projects.
4757
"""
48-
assert expected == _parse_instance_connection_name(connection_name)
58+
assert expected == _parse_connection_name(connection_name)
4959

5060

51-
def test_parse_instance_connection_name_bad_conn_name() -> None:
61+
def test_parse_connection_name_bad_conn_name() -> None:
5262
"""
5363
Tests that ValueError is thrown for bad instance connection names.
5464
"""
5565
with pytest.raises(ValueError):
56-
_parse_instance_connection_name("project:instance") # missing region
66+
_parse_connection_name("project:instance") # missing region

0 commit comments

Comments
 (0)