Skip to content

Commit a7d8681

Browse files
committed
refactored the duplicate mock DNS resolution code into a reusable fixture
1 parent b39118f commit a7d8681

File tree

1 file changed

+46
-111
lines changed

1 file changed

+46
-111
lines changed

tests/test_resolvers.py

Lines changed: 46 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
if sys.version_info >= (3, 11):
1515
from builtins import BaseExceptionGroup
1616
else:
17+
1718
class BaseExceptionGroup(Exception):
1819
pass
1920

@@ -24,17 +25,9 @@ def dns_resolver():
2425
return DNSResolver()
2526

2627

27-
@pytest.mark.trio
28-
async def test_resolve_non_dns_addr(dns_resolver):
29-
"""Test resolving a non-DNS multiaddr."""
30-
ma = Multiaddr("/ip4/127.0.0.1/tcp/1234")
31-
result = await dns_resolver.resolve(ma)
32-
assert result == [ma]
33-
34-
35-
@pytest.mark.trio
36-
async def test_resolve_dns_addr(dns_resolver):
37-
"""Test resolving a DNS multiaddr."""
28+
@pytest.fixture
29+
def mock_dns_resolution():
30+
"""Create mock DNS resolution setup for testing."""
3831
# Create mock DNS answer for A record (IPv4)
3932
mock_answer_a = AsyncMock()
4033
mock_rdata_a = AsyncMock()
@@ -45,17 +38,35 @@ async def test_resolve_dns_addr(dns_resolver):
4538
mock_answer_aaaa = AsyncMock()
4639
mock_answer_aaaa.__iter__.return_value = []
4740

48-
with patch.object(dns_resolver._resolver, 'resolve') as mock_resolve:
49-
# Configure the mock to return different results based on record type
50-
async def mock_resolve_side_effect(hostname, record_type):
51-
if record_type == "A":
52-
return mock_answer_a
53-
elif record_type == "AAAA":
54-
return mock_answer_aaaa
55-
else:
56-
raise dns.resolver.NXDOMAIN()
41+
# Configure the mock to return different results based on record type
42+
async def mock_resolve_side_effect(hostname, record_type):
43+
if record_type == "A":
44+
return mock_answer_a
45+
elif record_type == "AAAA":
46+
return mock_answer_aaaa
47+
else:
48+
raise dns.resolver.NXDOMAIN()
49+
50+
return {
51+
"mock_answer_a": mock_answer_a,
52+
"mock_answer_aaaa": mock_answer_aaaa,
53+
"mock_resolve_side_effect": mock_resolve_side_effect,
54+
}
55+
56+
57+
@pytest.mark.trio
58+
async def test_resolve_non_dns_addr(dns_resolver):
59+
"""Test resolving a non-DNS multiaddr."""
60+
ma = Multiaddr("/ip4/127.0.0.1/tcp/1234")
61+
result = await dns_resolver.resolve(ma)
62+
assert result == [ma]
63+
5764

58-
mock_resolve.side_effect = mock_resolve_side_effect
65+
@pytest.mark.trio
66+
async def test_resolve_dns_addr(dns_resolver, mock_dns_resolution):
67+
"""Test resolving a DNS multiaddr."""
68+
with patch.object(dns_resolver._resolver, "resolve") as mock_resolve:
69+
mock_resolve.side_effect = mock_dns_resolution["mock_resolve_side_effect"]
5970

6071
ma = Multiaddr("/dnsaddr/example.com")
6172
result = await dns_resolver.resolve(ma)
@@ -65,29 +76,10 @@ async def mock_resolve_side_effect(hostname, record_type):
6576

6677

6778
@pytest.mark.trio
68-
async def test_resolve_dns_addr_with_peer_id(dns_resolver):
79+
async def test_resolve_dns_addr_with_peer_id(dns_resolver, mock_dns_resolution):
6980
"""Test resolving a DNS multiaddr with a peer ID."""
70-
# Create mock DNS answer for A record (IPv4)
71-
mock_answer_a = AsyncMock()
72-
mock_rdata_a = AsyncMock()
73-
mock_rdata_a.address = "127.0.0.1"
74-
mock_answer_a.__iter__.return_value = [mock_rdata_a]
75-
76-
# Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
77-
mock_answer_aaaa = AsyncMock()
78-
mock_answer_aaaa.__iter__.return_value = []
79-
80-
with patch.object(dns_resolver._resolver, 'resolve') as mock_resolve:
81-
# Configure the mock to return different results based on record type
82-
async def mock_resolve_side_effect(hostname, record_type):
83-
if record_type == "A":
84-
return mock_answer_a
85-
elif record_type == "AAAA":
86-
return mock_answer_aaaa
87-
else:
88-
raise dns.resolver.NXDOMAIN()
89-
90-
mock_resolve.side_effect = mock_resolve_side_effect
81+
with patch.object(dns_resolver._resolver, "resolve") as mock_resolve:
82+
mock_resolve.side_effect = mock_dns_resolution["mock_resolve_side_effect"]
9183

9284
ma = Multiaddr("/dnsaddr/example.com/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7wjh53Qk")
9385
result = await dns_resolver.resolve(ma)
@@ -98,29 +90,10 @@ async def mock_resolve_side_effect(hostname, record_type):
9890

9991

10092
@pytest.mark.trio
101-
async def test_resolve_recursive_dns_addr(dns_resolver):
93+
async def test_resolve_recursive_dns_addr(dns_resolver, mock_dns_resolution):
10294
"""Test resolving a recursive DNS multiaddr."""
103-
# Create mock DNS answer for A record (IPv4)
104-
mock_answer_a = AsyncMock()
105-
mock_rdata_a = AsyncMock()
106-
mock_rdata_a.address = "127.0.0.1"
107-
mock_answer_a.__iter__.return_value = [mock_rdata_a]
108-
109-
# Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
110-
mock_answer_aaaa = AsyncMock()
111-
mock_answer_aaaa.__iter__.return_value = []
112-
113-
with patch.object(dns_resolver._resolver, 'resolve') as mock_resolve:
114-
# Configure the mock to return different results based on record type
115-
async def mock_resolve_side_effect(hostname, record_type):
116-
if record_type == "A":
117-
return mock_answer_a
118-
elif record_type == "AAAA":
119-
return mock_answer_aaaa
120-
else:
121-
raise dns.resolver.NXDOMAIN()
122-
123-
mock_resolve.side_effect = mock_resolve_side_effect
95+
with patch.object(dns_resolver._resolver, "resolve") as mock_resolve:
96+
mock_resolve.side_effect = mock_dns_resolution["mock_resolve_side_effect"]
12497

12598
ma = Multiaddr("/dnsaddr/example.com")
12699
result = await dns_resolver.resolve(ma, {"max_recursive_depth": 2})
@@ -140,37 +113,18 @@ async def test_resolve_recursion_limit(dns_resolver):
140113
@pytest.mark.trio
141114
async def test_resolve_dns_addr_error(dns_resolver):
142115
"""Test handling DNS resolution errors."""
143-
with patch.object(dns_resolver._resolver, 'resolve', side_effect=dns.resolver.NXDOMAIN):
116+
with patch.object(dns_resolver._resolver, "resolve", side_effect=dns.resolver.NXDOMAIN):
144117
ma = Multiaddr("/dnsaddr/example.com")
145118
# When DNS resolution fails, the resolver should return the original multiaddr
146119
result = await dns_resolver.resolve(ma)
147120
assert result == [ma]
148121

149122

150123
@pytest.mark.trio
151-
async def test_resolve_dns_addr_with_quotes(dns_resolver):
124+
async def test_resolve_dns_addr_with_quotes(dns_resolver, mock_dns_resolution):
152125
"""Test resolving DNS records with quoted strings."""
153-
# Create mock DNS answer for A record (IPv4)
154-
mock_answer_a = AsyncMock()
155-
mock_rdata_a = AsyncMock()
156-
mock_rdata_a.address = "127.0.0.1"
157-
mock_answer_a.__iter__.return_value = [mock_rdata_a]
158-
159-
# Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
160-
mock_answer_aaaa = AsyncMock()
161-
mock_answer_aaaa.__iter__.return_value = []
162-
163-
with patch.object(dns_resolver._resolver, 'resolve') as mock_resolve:
164-
# Configure the mock to return different results based on record type
165-
async def mock_resolve_side_effect(hostname, record_type):
166-
if record_type == "A":
167-
return mock_answer_a
168-
elif record_type == "AAAA":
169-
return mock_answer_aaaa
170-
else:
171-
raise dns.resolver.NXDOMAIN()
172-
173-
mock_resolve.side_effect = mock_resolve_side_effect
126+
with patch.object(dns_resolver._resolver, "resolve") as mock_resolve:
127+
mock_resolve.side_effect = mock_dns_resolution["mock_resolve_side_effect"]
174128

175129
ma = Multiaddr("/dnsaddr/example.com")
176130
result = await dns_resolver.resolve(ma)
@@ -180,29 +134,10 @@ async def mock_resolve_side_effect(hostname, record_type):
180134

181135

182136
@pytest.mark.trio
183-
async def test_resolve_dns_addr_with_mixed_quotes(dns_resolver):
137+
async def test_resolve_dns_addr_with_mixed_quotes(dns_resolver, mock_dns_resolution):
184138
"""Test resolving DNS records with mixed quotes."""
185-
# Create mock DNS answer for A record (IPv4)
186-
mock_answer_a = AsyncMock()
187-
mock_rdata_a = AsyncMock()
188-
mock_rdata_a.address = "127.0.0.1"
189-
mock_answer_a.__iter__.return_value = [mock_rdata_a]
190-
191-
# Create mock DNS answer for AAAA record (IPv6) - return empty to avoid conflicts
192-
mock_answer_aaaa = AsyncMock()
193-
mock_answer_aaaa.__iter__.return_value = []
194-
195-
with patch.object(dns_resolver._resolver, 'resolve') as mock_resolve:
196-
# Configure the mock to return different results based on record type
197-
async def mock_resolve_side_effect(hostname, record_type):
198-
if record_type == "A":
199-
return mock_answer_a
200-
elif record_type == "AAAA":
201-
return mock_answer_aaaa
202-
else:
203-
raise dns.resolver.NXDOMAIN()
204-
205-
mock_resolve.side_effect = mock_resolve_side_effect
139+
with patch.object(dns_resolver._resolver, "resolve") as mock_resolve:
140+
mock_resolve.side_effect = mock_dns_resolution["mock_resolve_side_effect"]
206141

207142
ma = Multiaddr("/dnsaddr/example.com")
208143
result = await dns_resolver.resolve(ma)
@@ -224,7 +159,7 @@ async def slow_dns_resolve(*args, **kwargs):
224159
await trio.sleep(0.5) # Long sleep to allow cancellation
225160
raise dns.resolver.NXDOMAIN("Domain not found")
226161

227-
with patch.object(dns_resolver._resolver, 'resolve', side_effect=slow_dns_resolve):
162+
with patch.object(dns_resolver._resolver, "resolve", side_effect=slow_dns_resolve):
228163
# Start resolution in background and cancel it
229164
async with trio.open_nursery() as nursery:
230165
# Start the resolution

0 commit comments

Comments
 (0)