Skip to content

Commit 720d1ad

Browse files
refactor: add domain_name to ConnectionName class (#1209)
1 parent 1a8f274 commit 720d1ad

File tree

3 files changed

+67
-13
lines changed

3 files changed

+67
-13
lines changed

google/cloud/sql/connector/connection_name.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,21 @@ class ConnectionName:
3131
project: str
3232
region: str
3333
instance_name: str
34+
domain_name: str = ""
3435

3536
def __str__(self) -> str:
37+
if self.domain_name:
38+
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
3639
return f"{self.project}:{self.region}:{self.instance_name}"
3740

3841

39-
def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
42+
def _parse_connection_name(connection_name: str) -> ConnectionName:
43+
return _parse_connection_name_with_domain_name(connection_name, "")
44+
45+
46+
def _parse_connection_name_with_domain_name(
47+
connection_name: str, domain_name: str
48+
) -> ConnectionName:
4049
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
4150
raise ValueError(
4251
"Arg `instance_connection_string` must have "
@@ -48,4 +57,5 @@ def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
4857
connection_name_split[1],
4958
connection_name_split[3],
5059
connection_name_split[4],
60+
domain_name,
5161
)

google/cloud/sql/connector/resolver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import dns.asyncresolver
1616

17-
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
17+
from google.cloud.sql.connector.connection_name import _parse_connection_name
1818
from google.cloud.sql.connector.connection_name import ConnectionName
1919
from google.cloud.sql.connector.exceptions import DnsResolutionError
2020

@@ -23,7 +23,7 @@ class DefaultResolver:
2323
"""DefaultResolver simply validates and parses instance connection name."""
2424

2525
async def resolve(self, connection_name: str) -> ConnectionName:
26-
return _parse_instance_connection_name(connection_name)
26+
return _parse_connection_name(connection_name)
2727

2828

2929
class DnsResolver(dns.asyncresolver.Resolver):
@@ -34,7 +34,7 @@ class DnsResolver(dns.asyncresolver.Resolver):
3434

3535
async def resolve(self, dns: str) -> ConnectionName: # type: ignore
3636
try:
37-
conn_name = _parse_instance_connection_name(dns)
37+
conn_name = _parse_connection_name(dns)
3838
except ValueError:
3939
# The connection name was not project:region:instance format.
4040
# Attempt to query a TXT record to get connection name.
@@ -52,7 +52,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
5252
# Attempt to parse records, returning the first valid record.
5353
for record in rdata:
5454
try:
55-
conn_name = _parse_instance_connection_name(record)
55+
conn_name = _parse_connection_name(record)
5656
return conn_name
5757
except Exception:
5858
continue

tests/unit/test_connection_name.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,37 @@
1414

1515
import pytest # noqa F401 Needed to run the tests
1616

17-
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
17+
# fmt: off
18+
from google.cloud.sql.connector.connection_name import _parse_connection_name
19+
from google.cloud.sql.connector.connection_name import \
20+
_parse_connection_name_with_domain_name
1821
from google.cloud.sql.connector.connection_name import ConnectionName
1922

23+
# fmt: on
24+
2025

2126
def test_ConnectionName() -> None:
2227
conn_name = ConnectionName("project", "region", "instance")
2328
# test class attributes are set properly
2429
assert conn_name.project == "project"
2530
assert conn_name.region == "region"
2631
assert conn_name.instance_name == "instance"
32+
assert conn_name.domain_name == ""
2733
# test ConnectionName str() method prints instance connection name
2834
assert str(conn_name) == "project:region:instance"
2935

3036

37+
def test_ConnectionName_with_domain_name() -> None:
38+
conn_name = ConnectionName("project", "region", "instance", "db.example.com")
39+
# test class attributes are set properly
40+
assert conn_name.project == "project"
41+
assert conn_name.region == "region"
42+
assert conn_name.instance_name == "instance"
43+
assert conn_name.domain_name == "db.example.com"
44+
# test ConnectionName str() method prints with domain name
45+
assert str(conn_name) == "db.example.com -> project:region:instance"
46+
47+
3148
@pytest.mark.parametrize(
3249
"connection_name, expected",
3350
[
@@ -38,19 +55,46 @@ def test_ConnectionName() -> None:
3855
),
3956
],
4057
)
41-
def test_parse_instance_connection_name(
42-
connection_name: str, expected: ConnectionName
43-
) -> None:
58+
def test_parse_connection_name(connection_name: str, expected: ConnectionName) -> None:
4459
"""
45-
Test that _parse_instance_connection_name works correctly on
60+
Test that _parse_connection_name works correctly on
4661
normal instance connection names and domain-scoped projects.
4762
"""
48-
assert expected == _parse_instance_connection_name(connection_name)
63+
assert expected == _parse_connection_name(connection_name)
4964

5065

51-
def test_parse_instance_connection_name_bad_conn_name() -> None:
66+
def test_parse_connection_name_bad_conn_name() -> None:
5267
"""
5368
Tests that ValueError is thrown for bad instance connection names.
5469
"""
5570
with pytest.raises(ValueError):
56-
_parse_instance_connection_name("project:instance") # missing region
71+
_parse_connection_name("project:instance") # missing region
72+
73+
74+
@pytest.mark.parametrize(
75+
"connection_name, domain_name, expected",
76+
[
77+
(
78+
"project:region:instance",
79+
"db.example.com",
80+
ConnectionName("project", "region", "instance", "db.example.com"),
81+
),
82+
(
83+
"domain-prefix:project:region:instance",
84+
"db.example.com",
85+
ConnectionName(
86+
"domain-prefix:project", "region", "instance", "db.example.com"
87+
),
88+
),
89+
],
90+
)
91+
def test_parse_connection_name_with_domain_name(
92+
connection_name: str, domain_name: str, expected: ConnectionName
93+
) -> None:
94+
"""
95+
Test that _parse_connection_name_with_domain_name works correctly on
96+
normal instance connection names and domain-scoped projects.
97+
"""
98+
assert expected == _parse_connection_name_with_domain_name(
99+
connection_name, domain_name
100+
)

0 commit comments

Comments
 (0)