1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import dns .message
16+ import dns .rdataclass
17+ import dns .rdatatype
18+ import dns .resolver
19+ from mock import patch
1520import pytest
1621
1722from google .cloud .sql .connector .connection_name import ConnectionName
1823from google .cloud .sql .connector .exceptions import DnsResolutionError
1924from google .cloud .sql .connector .resolver import DefaultResolver
2025from google .cloud .sql .connector .resolver import DnsResolver
2126
22- conn_str = "test -project:test -region:test -instance"
23- conn_name = ConnectionName ("test -project" , "test -region" , "test -instance" )
27+ conn_str = "my -project:my -region:my -instance"
28+ conn_name = ConnectionName ("my -project" , "my -region" , "my -instance" )
2429
2530
2631async def test_DefaultResolver () -> None :
@@ -37,39 +42,87 @@ async def test_DnsResolver_with_conn_str() -> None:
3742 assert result == conn_name
3843
3944
40- @pytest .mark .usefixtures ("dns_server" )
45+ query_text = """id 1234
46+ opcode QUERY
47+ rcode NOERROR
48+ flags QR AA RD RA
49+ ;QUESTION
50+ db.example.com. IN TXT
51+ ;ANSWER
52+ db.example.com. 0 IN TXT "test-project:test-region:test-instance"
53+ db.example.com. 0 IN TXT "my-project:my-region:my-instance"
54+ ;AUTHORITY
55+ ;ADDITIONAL
56+ """
57+
58+
4159async def test_DnsResolver_with_dns_name () -> None :
42- """Test DnsResolver resolves TXT record into proper instance connection name."""
43- resolver = DnsResolver ()
44- resolver .port = 5053
45- result = await resolver .resolve ("db.example.com" )
46- assert result == conn_name
60+ """Test DnsResolver resolves TXT record into proper instance connection name.
61+
62+ Should sort valid TXT records alphabetically and take first one.
63+ """
64+ # Patch DNS resolution with valid TXT records
65+ with patch ("dns.asyncresolver.Resolver.resolve" ) as mock_connect :
66+ answer = dns .resolver .Answer (
67+ "db.example.com" ,
68+ dns .rdatatype .TXT ,
69+ dns .rdataclass .IN ,
70+ dns .message .from_text (query_text ),
71+ )
72+ mock_connect .return_value = answer
73+ resolver = DnsResolver ()
74+ resolver .port = 5053
75+ # Resolution should return first value sorted alphabetically
76+ result = await resolver .resolve ("db.example.com" )
77+ assert result == conn_name
78+
79+
80+ query_text_malformed = """id 1234
81+ opcode QUERY
82+ rcode NOERROR
83+ flags QR AA RD RA
84+ ;QUESTION
85+ bad.example.com. IN TXT
86+ ;ANSWER
87+ bad.example.com. 0 IN TXT "malformed-instance-name"
88+ ;AUTHORITY
89+ ;ADDITIONAL
90+ """
4791
4892
49- @pytest .mark .usefixtures ("dns_server" )
5093async def test_DnsResolver_with_malformed_txt () -> None :
5194 """Test DnsResolver with TXT record that holds malformed instance connection name.
5295
5396 Should throw DnsResolutionError
5497 """
55- resolver = DnsResolver ()
56- resolver . port = 5053
57- with pytest . raises ( DnsResolutionError ) as exc_info :
58- await resolver . resolve ( "bad.example.com" )
59- assert (
60- exc_info . value . args [ 0 ]
61- == "Unable to parse TXT record for `bad.example.com` -> `bad-instance-name`"
98+ # patch DNS resolution with malformed TXT record
99+ with patch ( "dns.asyncresolver.Resolver.resolve" ) as mock_connect :
100+ answer = dns . resolver . Answer (
101+ "bad.example.com" ,
102+ dns . rdatatype . TXT ,
103+ dns . rdataclass . IN ,
104+ dns . message . from_text ( query_text_malformed ),
62105 )
106+ mock_connect .return_value = answer
107+ resolver = DnsResolver ()
108+ resolver .port = 5053
109+ with pytest .raises (DnsResolutionError ) as exc_info :
110+ await resolver .resolve ("bad.example.com" )
111+ assert (
112+ exc_info .value .args [0 ]
113+ == "Unable to parse TXT record for `bad.example.com` -> `malformed-instance-name`"
114+ )
63115
64116
65- @pytest .mark .usefixtures ("dns_server" )
66117async def test_DnsResolver_with_bad_dns_name () -> None :
67118 """Test DnsResolver with bad dns name.
68119
69120 Should throw DnsResolutionError
70121 """
71122 resolver = DnsResolver ()
72123 resolver .port = 5053
124+ # set lifetime to 1 second for shorter timeout
125+ resolver .lifetime = 1
73126 with pytest .raises (DnsResolutionError ) as exc_info :
74127 await resolver .resolve ("bad.dns.com" )
75128 assert exc_info .value .args [0 ] == "Unable to resolve TXT record for `bad.dns.com`"
0 commit comments