Skip to content

Commit d260934

Browse files
chore: improve tests
1 parent a101003 commit d260934

File tree

2 files changed

+234
-0
lines changed

2 files changed

+234
-0
lines changed

google/cloud/sql/connector/monitored_cache.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,23 @@ def __init__(
5454
def closed(self) -> bool:
5555
return self.cache.closed
5656

57+
def _purge_closed_sockets(self) -> None:
58+
"""Remove closed sockets from monitored cache.
59+
60+
If a socket is closed by the database driver we should remove it from
61+
list of sockets.
62+
"""
63+
open_sockets = []
64+
for socket in self.sockets:
65+
# Check fileno for if socket is closed. Will return
66+
# -1 on failure, which will be used to signal socket closed.
67+
if socket.fileno() != -1:
68+
open_sockets.append(socket)
69+
self.sockets = open_sockets
70+
5771
async def _check_domain_name(self) -> None:
72+
# remove any closed connections from cache
73+
self._purge_closed_sockets()
5874
try:
5975
# Resolve domain name and see if Cloud SQL instance connection name
6076
# has changed. If it has, close all connections.

tests/unit/test_monitored_cache.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import socket
17+
18+
import dns.message
19+
import dns.rdataclass
20+
import dns.rdatatype
21+
import dns.resolver
22+
from mock import patch
23+
from mocks import create_ssl_context
24+
import pytest
25+
26+
from google.cloud.sql.connector.client import CloudSQLClient
27+
from google.cloud.sql.connector.connection_name import ConnectionName
28+
from google.cloud.sql.connector.lazy import LazyRefreshCache
29+
from google.cloud.sql.connector.monitored_cache import MonitoredCache
30+
from google.cloud.sql.connector.resolver import DefaultResolver
31+
from google.cloud.sql.connector.resolver import DnsResolver
32+
from google.cloud.sql.connector.utils import generate_keys
33+
34+
query_text = """id 1234
35+
opcode QUERY
36+
rcode NOERROR
37+
flags QR AA RD RA
38+
;QUESTION
39+
db.example.com. IN TXT
40+
;ANSWER
41+
db.example.com. 0 IN TXT "test-project:test-region:test-instance"
42+
;AUTHORITY
43+
;ADDITIONAL
44+
"""
45+
46+
47+
async def test_MonitoredCache_properties(fake_client: CloudSQLClient) -> None:
48+
"""
49+
Test that MonitoredCache properties work as expected.
50+
"""
51+
conn_name = ConnectionName("test-project", "test-region", "test-instance")
52+
cache = LazyRefreshCache(
53+
conn_name,
54+
client=fake_client,
55+
keys=asyncio.create_task(generate_keys()),
56+
enable_iam_auth=False,
57+
)
58+
monitored_cache = MonitoredCache(cache, 30, DefaultResolver())
59+
# test that ticker is not set for instance not using domain name
60+
assert monitored_cache.domain_name_ticker is None
61+
# test closed property
62+
assert monitored_cache.closed is False
63+
# close cache and make sure property is updated
64+
await monitored_cache.close()
65+
assert monitored_cache.closed is True
66+
67+
68+
async def test_MonitoredCache_with_DnsResolver(fake_client: CloudSQLClient) -> None:
69+
"""
70+
Test that MonitoredCache with DnsResolver work as expected.
71+
"""
72+
conn_name = ConnectionName(
73+
"test-project", "test-region", "test-instance", "db.example.com"
74+
)
75+
cache = LazyRefreshCache(
76+
conn_name,
77+
client=fake_client,
78+
keys=asyncio.create_task(generate_keys()),
79+
enable_iam_auth=False,
80+
)
81+
# Patch DNS resolution with valid TXT records
82+
with patch("dns.asyncresolver.Resolver.resolve") as mock_connect:
83+
answer = dns.resolver.Answer(
84+
"db.example.com",
85+
dns.rdatatype.TXT,
86+
dns.rdataclass.IN,
87+
dns.message.from_text(query_text),
88+
)
89+
mock_connect.return_value = answer
90+
resolver = DnsResolver()
91+
resolver.port = 5053
92+
monitored_cache = MonitoredCache(cache, 30, resolver)
93+
# test that ticker is set for instance using domain name
94+
assert type(monitored_cache.domain_name_ticker) is asyncio.Task
95+
# test closed property
96+
assert monitored_cache.closed is False
97+
# close cache and make sure property is updated
98+
await monitored_cache.close()
99+
assert monitored_cache.closed is True
100+
# domain name ticker should be set back to None
101+
assert monitored_cache.domain_name_ticker is None
102+
103+
104+
async def test_MonitoredCache_with_disabled_failover(
105+
fake_client: CloudSQLClient,
106+
) -> None:
107+
"""
108+
Test that MonitoredCache disables DNS polling with failover_period=0
109+
"""
110+
conn_name = ConnectionName(
111+
"test-project", "test-region", "test-instance", "db.example.com"
112+
)
113+
cache = LazyRefreshCache(
114+
conn_name,
115+
client=fake_client,
116+
keys=asyncio.create_task(generate_keys()),
117+
enable_iam_auth=False,
118+
)
119+
monitored_cache = MonitoredCache(cache, 0, DnsResolver())
120+
# test that ticker is not set when failover is disabled
121+
assert monitored_cache.domain_name_ticker is None
122+
# test closed property
123+
assert monitored_cache.closed is False
124+
# close cache and make sure property is updated
125+
await monitored_cache.close()
126+
assert monitored_cache.closed is True
127+
128+
129+
@pytest.mark.usefixtures("server")
130+
async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None:
131+
"""
132+
Test that MonitoredCache is closed when _check_domain_name has domain change.
133+
"""
134+
conn_name = ConnectionName(
135+
"my-project", "my-region", "my-instance", "db.example.com"
136+
)
137+
cache = LazyRefreshCache(
138+
conn_name,
139+
client=fake_client,
140+
keys=asyncio.create_task(generate_keys()),
141+
enable_iam_auth=False,
142+
)
143+
# Patch DNS resolution with valid TXT records
144+
with patch("dns.asyncresolver.Resolver.resolve") as mock_connect:
145+
answer = dns.resolver.Answer(
146+
"db.example.com",
147+
dns.rdatatype.TXT,
148+
dns.rdataclass.IN,
149+
dns.message.from_text(query_text),
150+
)
151+
mock_connect.return_value = answer
152+
resolver = DnsResolver()
153+
resolver.port = 5053
154+
155+
# configure a local socket
156+
ip_addr = "127.0.0.1"
157+
context = await create_ssl_context()
158+
sock = context.wrap_socket(
159+
socket.create_connection((ip_addr, 3307)),
160+
server_hostname=ip_addr,
161+
do_handshake_on_connect=False,
162+
)
163+
# verify socket is open
164+
assert sock.fileno() != -1
165+
# set failover to 0 to disable polling
166+
monitored_cache = MonitoredCache(cache, 0, resolver)
167+
# add socket to cache
168+
monitored_cache.sockets = [sock]
169+
# check cache is not closed
170+
assert monitored_cache.closed is False
171+
# call _check_domain_name and verify cache is closed
172+
await monitored_cache._check_domain_name()
173+
assert monitored_cache.closed is True
174+
# verify socket was closed
175+
assert sock.fileno() == -1
176+
177+
178+
@pytest.mark.usefixtures("server")
179+
async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None:
180+
"""
181+
Test that MonitoredCache._purge_closed_sockets removes closed sockets from
182+
cache.
183+
"""
184+
conn_name = ConnectionName(
185+
"my-project", "my-region", "my-instance", "db.example.com"
186+
)
187+
cache = LazyRefreshCache(
188+
conn_name,
189+
client=fake_client,
190+
keys=asyncio.create_task(generate_keys()),
191+
enable_iam_auth=False,
192+
)
193+
# configure a local socket
194+
ip_addr = "127.0.0.1"
195+
context = await create_ssl_context()
196+
sock = context.wrap_socket(
197+
socket.create_connection((ip_addr, 3307)),
198+
server_hostname=ip_addr,
199+
do_handshake_on_connect=False,
200+
)
201+
202+
# set failover to 0 to disable polling
203+
monitored_cache = MonitoredCache(cache, 0, DnsResolver())
204+
# verify socket is open
205+
assert sock.fileno() != -1
206+
# add socket to cache
207+
monitored_cache.sockets = [sock]
208+
# call _purge_closed_sockets and verify socket remains
209+
monitored_cache._purge_closed_sockets()
210+
# verify socket is still open
211+
assert sock.fileno() != -1
212+
assert len(monitored_cache.sockets) == 1
213+
# close socket
214+
sock.close()
215+
# call _purge_closed_sockets and verify socket is clsoed
216+
monitored_cache._purge_closed_sockets()
217+
assert len(monitored_cache.sockets) == 0
218+
assert sock.fileno() == -1

0 commit comments

Comments
 (0)