Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion google/cloud/sql/connector/connection_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,21 @@ class ConnectionName:
project: str
region: str
instance_name: str
domain_name: str = ""

def __str__(self) -> str:
if self.domain_name:
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
return f"{self.project}:{self.region}:{self.instance_name}"


def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
def _parse_connection_name(connection_name: str) -> ConnectionName:
return _parse_connection_name_with_domain_name(connection_name, "")


def _parse_connection_name_with_domain_name(
connection_name: str, domain_name: str
) -> ConnectionName:
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
raise ValueError(
"Arg `instance_connection_string` must have "
Expand All @@ -48,4 +57,5 @@ def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
connection_name_split[1],
connection_name_split[3],
connection_name_split[4],
domain_name,
)
8 changes: 4 additions & 4 deletions google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import dns.asyncresolver

from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError

Expand All @@ -23,7 +23,7 @@ class DefaultResolver:
"""DefaultResolver simply validates and parses instance connection name."""

async def resolve(self, connection_name: str) -> ConnectionName:
return _parse_instance_connection_name(connection_name)
return _parse_connection_name(connection_name)


class DnsResolver(dns.asyncresolver.Resolver):
Expand All @@ -34,7 +34,7 @@ class DnsResolver(dns.asyncresolver.Resolver):

async def resolve(self, dns: str) -> ConnectionName: # type: ignore
try:
conn_name = _parse_instance_connection_name(dns)
conn_name = _parse_connection_name(dns)
except ValueError:
# The connection name was not project:region:instance format.
# Attempt to query a TXT record to get connection name.
Expand All @@ -52,7 +52,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
# Attempt to parse records, returning the first valid record.
for record in rdata:
try:
conn_name = _parse_instance_connection_name(record)
conn_name = _parse_connection_name(record)
return conn_name
except Exception:
continue
Expand Down
60 changes: 52 additions & 8 deletions tests/unit/test_connection_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,37 @@

import pytest # noqa F401 Needed to run the tests

from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
# fmt: off
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import style of long import names differs between isort and black format, turning black formatter off for these imports to stop lint job from complaining.

from google.cloud.sql.connector.connection_name import _parse_connection_name
from google.cloud.sql.connector.connection_name import \
_parse_connection_name_with_domain_name
from google.cloud.sql.connector.connection_name import ConnectionName

# fmt: on


def test_ConnectionName() -> None:
conn_name = ConnectionName("project", "region", "instance")
# test class attributes are set properly
assert conn_name.project == "project"
assert conn_name.region == "region"
assert conn_name.instance_name == "instance"
assert conn_name.domain_name == ""
# test ConnectionName str() method prints instance connection name
assert str(conn_name) == "project:region:instance"


def test_ConnectionName_with_domain_name() -> None:
conn_name = ConnectionName("project", "region", "instance", "db.example.com")
# test class attributes are set properly
assert conn_name.project == "project"
assert conn_name.region == "region"
assert conn_name.instance_name == "instance"
assert conn_name.domain_name == "db.example.com"
# test ConnectionName str() method prints with domain name
assert str(conn_name) == "db.example.com -> project:region:instance"


@pytest.mark.parametrize(
"connection_name, expected",
[
Expand All @@ -38,19 +55,46 @@ def test_ConnectionName() -> None:
),
],
)
def test_parse_instance_connection_name(
connection_name: str, expected: ConnectionName
) -> None:
def test_parse_connection_name(connection_name: str, expected: ConnectionName) -> None:
"""
Test that _parse_instance_connection_name works correctly on
Test that _parse_connection_name works correctly on
normal instance connection names and domain-scoped projects.
"""
assert expected == _parse_instance_connection_name(connection_name)
assert expected == _parse_connection_name(connection_name)


def test_parse_instance_connection_name_bad_conn_name() -> None:
def test_parse_connection_name_bad_conn_name() -> None:
"""
Tests that ValueError is thrown for bad instance connection names.
"""
with pytest.raises(ValueError):
_parse_instance_connection_name("project:instance") # missing region
_parse_connection_name("project:instance") # missing region


@pytest.mark.parametrize(
"connection_name, domain_name, expected",
[
(
"project:region:instance",
"db.example.com",
ConnectionName("project", "region", "instance", "db.example.com"),
),
(
"domain-prefix:project:region:instance",
"db.example.com",
ConnectionName(
"domain-prefix:project", "region", "instance", "db.example.com"
),
),
],
)
def test_parse_connection_name_with_domain_name(
connection_name: str, domain_name: str, expected: ConnectionName
) -> None:
"""
Test that _parse_connection_name_with_domain_name works correctly on
normal instance connection names and domain-scoped projects.
"""
assert expected == _parse_connection_name_with_domain_name(
connection_name, domain_name
)
Loading