Skip to content

Commit 8207ff2

Browse files
chore: add local dns server and tests
1 parent 5abcc08 commit 8207ff2

File tree

5 files changed

+67
-3
lines changed

5 files changed

+67
-3
lines changed

google/cloud/sql/connector/resolver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from dns.asyncresolver import Resolver
15+
import dns.asyncresolver
1616

1717
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
1818
from google.cloud.sql.connector.connection_name import ConnectionName
@@ -26,7 +26,7 @@ async def resolve(self, connection_name: str) -> ConnectionName:
2626
return _parse_instance_connection_name(connection_name)
2727

2828

29-
class DnsResolver(Resolver):
29+
class DnsResolver(dns.asyncresolver.Resolver):
3030
"""
3131
DnsResolver resolves domain names into instance connection names using
3232
TXT records in DNS.
@@ -58,7 +58,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
5858
continue
5959
# If all records failed to parse, throw error
6060
raise DnsResolutionError(
61-
f"Unable to parse TXT record for `{dns}` -> {rdata[0]}"
61+
f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`"
6262
)
6363
# Don't override above DnsResolutionError
6464
except DnsResolutionError:

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ asyncpg==0.30.0
1111
python-tds==1.16.0
1212
aioresponses==0.7.7
1313
pytest-aiohttp==1.0.5
14+
dnserver==0.4.0

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Any, AsyncGenerator, Generator
2222

2323
from aiohttp import web
24+
from dnserver import DNSServer
2425
import pytest # noqa F401 Needed to run the tests
2526
from unit.mocks import FakeCredentials # type: ignore
2627
from unit.mocks import FakeCSQLInstance # type: ignore
@@ -151,3 +152,13 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache
151152
)
152153
yield cache
153154
await cache.close()
155+
156+
157+
@pytest.fixture(autouse=True, scope="session")
158+
def dns_server() -> Generator:
159+
"""Setup local DNS server for tests with TXT records."""
160+
server = DNSServer.from_toml("tests/test_zones.toml", port=5053, upstream=None)
161+
server.start()
162+
assert server.is_running
163+
yield server
164+
server.stop()

tests/test_zones.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[[zones]]
2+
host = 'db.example.com'
3+
type = 'TXT'
4+
answer = "test-project:test-region:test-instance"
5+
6+
[[zones]]
7+
host = 'db.example.com'
8+
type = 'TXT'
9+
answer = "test-project2:test-region2:test-instance2"
10+
11+
[[zones]]
12+
host = 'bad.example.com'
13+
type = 'TXT'
14+
answer = "bad-instance-name"

tests/unit/test_resolver.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pytest
16+
1517
from google.cloud.sql.connector.connection_name import ConnectionName
18+
from google.cloud.sql.connector.exceptions import DnsResolutionError
1619
from google.cloud.sql.connector.resolver import DefaultResolver
1720
from google.cloud.sql.connector.resolver import DnsResolver
1821

@@ -32,3 +35,38 @@ async def test_DnsResolver_with_conn_str() -> None:
3235
resolver = DnsResolver()
3336
result = await resolver.resolve(conn_str)
3437
assert result == conn_name
38+
39+
40+
async def test_DnsResolver_with_dns_name() -> None:
41+
"""Test DnsResolver resolves TXT record into proper instance connection name."""
42+
resolver = DnsResolver()
43+
resolver.port = 5053
44+
result = await resolver.resolve(conn_str)
45+
assert result == conn_name
46+
47+
48+
async def test_DnsResolver_with_malformed_txt() -> None:
49+
"""Test DnsResolver with TXT record that holds malformed instance connection name.
50+
51+
Should throw DnsResolutionError
52+
"""
53+
resolver = DnsResolver()
54+
resolver.port = 5053
55+
with pytest.raises(DnsResolutionError) as exc_info:
56+
await resolver.resolve("bad.example.com")
57+
assert (
58+
exc_info.value.args[0]
59+
== "Unable to parse TXT record for `bad.example.com` -> `bad-instance-name`"
60+
)
61+
62+
63+
async def test_DnsResolver_with_bad_dns_name() -> None:
64+
"""Test DnsResolver with bad dns name.
65+
66+
Should throw DnsResolutionError
67+
"""
68+
resolver = DnsResolver()
69+
resolver.port = 5053
70+
with pytest.raises(DnsResolutionError) as exc_info:
71+
await resolver.resolve("bad.dns.com")
72+
assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`"

0 commit comments

Comments
 (0)