1
1
"""Tests for multiaddr resolvers."""
2
2
3
- import socket
4
3
import sys
5
- from unittest .mock import patch
4
+ from unittest .mock import AsyncMock , patch
6
5
6
+ import dns .resolver
7
7
import pytest
8
8
import trio
9
9
10
10
from multiaddr import Multiaddr
11
- from multiaddr .exceptions import RecursionLimitError , ResolutionError
11
+ from multiaddr .exceptions import RecursionLimitError
12
12
from multiaddr .resolvers import DNSResolver
13
13
14
14
if sys .version_info >= (3 , 11 ):
@@ -35,10 +35,28 @@ async def test_resolve_non_dns_addr(dns_resolver):
35
35
@pytest .mark .trio
36
36
async def test_resolve_dns_addr (dns_resolver ):
37
37
"""Test resolving a DNS multiaddr."""
38
- with patch ("socket.getaddrinfo" ) as mock_getaddrinfo :
39
- mock_getaddrinfo .return_value = [
40
- (socket .AF_INET , socket .SOCK_STREAM , socket .IPPROTO_TCP , "" , ("127.0.0.1" , 0 ))
41
- ]
38
+ # Create mock DNS answer for A record (IPv4)
39
+ mock_answer_a = AsyncMock ()
40
+ mock_rdata_a = AsyncMock ()
41
+ mock_rdata_a .address = "127.0.0.1"
42
+ mock_answer_a .__iter__ .return_value = [mock_rdata_a ]
43
+
44
+ # Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
45
+ mock_answer_aaaa = AsyncMock ()
46
+ mock_answer_aaaa .__iter__ .return_value = []
47
+
48
+ with patch .object (dns_resolver ._resolver , 'resolve' ) as mock_resolve :
49
+ # Configure the mock to return different results based on record type
50
+ async def mock_resolve_side_effect (hostname , record_type ):
51
+ if record_type == "A" :
52
+ return mock_answer_a
53
+ elif record_type == "AAAA" :
54
+ return mock_answer_aaaa
55
+ else :
56
+ raise dns .resolver .NXDOMAIN ()
57
+
58
+ mock_resolve .side_effect = mock_resolve_side_effect
59
+
42
60
ma = Multiaddr ("/dnsaddr/example.com" )
43
61
result = await dns_resolver .resolve (ma )
44
62
assert len (result ) == 1
@@ -49,10 +67,28 @@ async def test_resolve_dns_addr(dns_resolver):
49
67
@pytest .mark .trio
50
68
async def test_resolve_dns_addr_with_peer_id (dns_resolver ):
51
69
"""Test resolving a DNS multiaddr with a peer ID."""
52
- with patch ("socket.getaddrinfo" ) as mock_getaddrinfo :
53
- mock_getaddrinfo .return_value = [
54
- (socket .AF_INET , socket .SOCK_STREAM , socket .IPPROTO_TCP , "" , ("127.0.0.1" , 0 ))
55
- ]
70
+ # Create mock DNS answer for A record (IPv4)
71
+ mock_answer_a = AsyncMock ()
72
+ mock_rdata_a = AsyncMock ()
73
+ mock_rdata_a .address = "127.0.0.1"
74
+ mock_answer_a .__iter__ .return_value = [mock_rdata_a ]
75
+
76
+ # Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
77
+ mock_answer_aaaa = AsyncMock ()
78
+ mock_answer_aaaa .__iter__ .return_value = []
79
+
80
+ with patch .object (dns_resolver ._resolver , 'resolve' ) as mock_resolve :
81
+ # Configure the mock to return different results based on record type
82
+ async def mock_resolve_side_effect (hostname , record_type ):
83
+ if record_type == "A" :
84
+ return mock_answer_a
85
+ elif record_type == "AAAA" :
86
+ return mock_answer_aaaa
87
+ else :
88
+ raise dns .resolver .NXDOMAIN ()
89
+
90
+ mock_resolve .side_effect = mock_resolve_side_effect
91
+
56
92
ma = Multiaddr ("/dnsaddr/example.com/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7wjh53Qk" )
57
93
result = await dns_resolver .resolve (ma )
58
94
assert len (result ) == 1
@@ -64,11 +100,28 @@ async def test_resolve_dns_addr_with_peer_id(dns_resolver):
64
100
@pytest .mark .trio
65
101
async def test_resolve_recursive_dns_addr (dns_resolver ):
66
102
"""Test resolving a recursive DNS multiaddr."""
67
- with patch ("socket.getaddrinfo" ) as mock_getaddrinfo :
68
- mock_getaddrinfo .side_effect = [
69
- [(socket .AF_INET , socket .SOCK_STREAM , socket .IPPROTO_TCP , "" , ("127.0.0.1" , 0 ))],
70
- [(socket .AF_INET , socket .SOCK_STREAM , socket .IPPROTO_TCP , "" , ("192.168.1.1" , 0 ))],
71
- ]
103
+ # Create mock DNS answer for A record (IPv4)
104
+ mock_answer_a = AsyncMock ()
105
+ mock_rdata_a = AsyncMock ()
106
+ mock_rdata_a .address = "127.0.0.1"
107
+ mock_answer_a .__iter__ .return_value = [mock_rdata_a ]
108
+
109
+ # Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
110
+ mock_answer_aaaa = AsyncMock ()
111
+ mock_answer_aaaa .__iter__ .return_value = []
112
+
113
+ with patch .object (dns_resolver ._resolver , 'resolve' ) as mock_resolve :
114
+ # Configure the mock to return different results based on record type
115
+ async def mock_resolve_side_effect (hostname , record_type ):
116
+ if record_type == "A" :
117
+ return mock_answer_a
118
+ elif record_type == "AAAA" :
119
+ return mock_answer_aaaa
120
+ else :
121
+ raise dns .resolver .NXDOMAIN ()
122
+
123
+ mock_resolve .side_effect = mock_resolve_side_effect
124
+
72
125
ma = Multiaddr ("/dnsaddr/example.com" )
73
126
result = await dns_resolver .resolve (ma , {"max_recursive_depth" : 2 })
74
127
assert len (result ) == 1
@@ -87,20 +140,38 @@ async def test_resolve_recursion_limit(dns_resolver):
87
140
@pytest .mark .trio
88
141
async def test_resolve_dns_addr_error (dns_resolver ):
89
142
"""Test handling DNS resolution errors."""
90
- with patch ("socket.getaddrinfo" ) as mock_getaddrinfo :
91
- mock_getaddrinfo .side_effect = socket .gaierror ("DNS resolution failed" )
143
+ with patch .object (dns_resolver ._resolver , 'resolve' , side_effect = dns .resolver .NXDOMAIN ):
92
144
ma = Multiaddr ("/dnsaddr/example.com" )
93
- with pytest .raises (ResolutionError ):
94
- await dns_resolver .resolve (ma )
145
+ # When DNS resolution fails, the resolver should return the original multiaddr
146
+ result = await dns_resolver .resolve (ma )
147
+ assert result == [ma ]
95
148
96
149
97
150
@pytest .mark .trio
98
151
async def test_resolve_dns_addr_with_quotes (dns_resolver ):
99
152
"""Test resolving DNS records with quoted strings."""
100
- with patch ("socket.getaddrinfo" ) as mock_getaddrinfo :
101
- mock_getaddrinfo .return_value = [
102
- (socket .AF_INET , socket .SOCK_STREAM , socket .IPPROTO_TCP , "" , ("127.0.0.1" , 0 ))
103
- ]
153
+ # Create mock DNS answer for A record (IPv4)
154
+ mock_answer_a = AsyncMock ()
155
+ mock_rdata_a = AsyncMock ()
156
+ mock_rdata_a .address = "127.0.0.1"
157
+ mock_answer_a .__iter__ .return_value = [mock_rdata_a ]
158
+
159
+ # Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
160
+ mock_answer_aaaa = AsyncMock ()
161
+ mock_answer_aaaa .__iter__ .return_value = []
162
+
163
+ with patch .object (dns_resolver ._resolver , 'resolve' ) as mock_resolve :
164
+ # Configure the mock to return different results based on record type
165
+ async def mock_resolve_side_effect (hostname , record_type ):
166
+ if record_type == "A" :
167
+ return mock_answer_a
168
+ elif record_type == "AAAA" :
169
+ return mock_answer_aaaa
170
+ else :
171
+ raise dns .resolver .NXDOMAIN ()
172
+
173
+ mock_resolve .side_effect = mock_resolve_side_effect
174
+
104
175
ma = Multiaddr ("/dnsaddr/example.com" )
105
176
result = await dns_resolver .resolve (ma )
106
177
assert len (result ) == 1
@@ -111,45 +182,56 @@ async def test_resolve_dns_addr_with_quotes(dns_resolver):
111
182
@pytest .mark .trio
112
183
async def test_resolve_dns_addr_with_mixed_quotes (dns_resolver ):
113
184
"""Test resolving DNS records with mixed quotes."""
114
- with patch ("socket.getaddrinfo" ) as mock_getaddrinfo :
115
- mock_getaddrinfo .return_value = [
116
- (socket .AF_INET , socket .SOCK_STREAM , socket .IPPROTO_TCP , "" , ("127.0.0.1" , 0 ))
117
- ]
185
+ # Create mock DNS answer for A record (IPv4)
186
+ mock_answer_a = AsyncMock ()
187
+ mock_rdata_a = AsyncMock ()
188
+ mock_rdata_a .address = "127.0.0.1"
189
+ mock_answer_a .__iter__ .return_value = [mock_rdata_a ]
190
+
191
+ # Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
192
+ mock_answer_aaaa = AsyncMock ()
193
+ mock_answer_aaaa .__iter__ .return_value = []
194
+
195
+ with patch .object (dns_resolver ._resolver , 'resolve' ) as mock_resolve :
196
+ # Configure the mock to return different results based on record type
197
+ async def mock_resolve_side_effect (hostname , record_type ):
198
+ if record_type == "A" :
199
+ return mock_answer_a
200
+ elif record_type == "AAAA" :
201
+ return mock_answer_aaaa
202
+ else :
203
+ raise dns .resolver .NXDOMAIN ()
204
+
205
+ mock_resolve .side_effect = mock_resolve_side_effect
206
+
118
207
ma = Multiaddr ("/dnsaddr/example.com" )
119
208
result = await dns_resolver .resolve (ma )
120
209
assert len (result ) == 1
121
210
assert result [0 ].protocols ()[0 ].name == "ip4"
122
211
assert result [0 ].value_for_protocol (result [0 ].protocols ()[0 ].code ) == "127.0.0.1"
123
212
124
213
125
- @pytest .mark .xfail (
126
- sys .version_info >= (3 , 11 ),
127
- reason = "ExceptionGroup not properly caught by pytest in async code (Python 3.11+)"
128
- )
129
214
@pytest .mark .trio
130
215
async def test_resolve_cancellation_with_error ():
131
- """Test that resolution can be cancelled and errors are properly handled."""
132
- ma = Multiaddr ("/dnsaddr/example.com" )
133
- signal = trio .CancelScope ()
216
+ """Test that DNS resolution can be cancelled."""
217
+ ma = Multiaddr ("/dnsaddr/nonexistent.example.com" )
218
+ signal = trio .CancelScope () # type: ignore[call-arg]
219
+ signal .cancelled_caught = True
134
220
dns_resolver = DNSResolver ()
135
221
136
- async def cancel_soon (scope ):
137
- await trio .sleep (0.01 )
138
- scope .cancel ()
139
-
140
- async def run_resolver ():
141
- await dns_resolver .resolve (ma , {"signal" : signal })
222
+ # Mock the DNS resolver to simulate a slow lookup that can be cancelled
223
+ async def slow_dns_resolve (* args , ** kwargs ):
224
+ await trio .sleep (0.5 ) # Long sleep to allow cancellation
225
+ raise dns .resolver .NXDOMAIN ("Domain not found" )
142
226
143
- try :
227
+ with patch .object (dns_resolver ._resolver , 'resolve' , side_effect = slow_dns_resolve ):
228
+ # Start resolution in background and cancel it
144
229
async with trio .open_nursery () as nursery :
145
- nursery .start_soon (cancel_soon , signal )
146
- nursery .start_soon (run_resolver )
147
- except BaseExceptionGroup as eg :
148
- # Check that at least one sub-exception is a cancellation
149
- assert any (
150
- isinstance (e , BaseException )
151
- and type (e ).__name__ .startswith ("Cancel" )
152
- for e in eg .exceptions
153
- )
154
- else :
155
- assert False , "Expected cancellation exception group"
230
+ # Start the resolution
231
+ nursery .start_soon (dns_resolver .resolve , ma , {"signal" : signal })
232
+
233
+ # Cancel after a short delay
234
+ await trio .sleep (0.1 )
235
+ signal .cancel ()
236
+
237
+ # The nursery should handle the cancellation
0 commit comments