Skip to content

Commit 834ea61

Browse files
Tasssadarbdraco
andauthored
[PR #11150/996ad00 backport][3.13] fix: leak of aiodns.DNSResolver when ClientSession is closed (#11151)
Co-authored-by: J. Nick Koston <[email protected]>
1 parent 5991031 commit 834ea61

File tree

5 files changed

+41
-8
lines changed

5 files changed

+41
-8
lines changed

CHANGES/11150.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed leak of ``aiodns.DNSResolver`` when :py:class:`~aiohttp.TCPConnector` is closed and no resolver was passed when creating the connector -- by :user:`Tasssadar`.
2+
3+
This was a regression introduced in version 3.12.0 (:pr:`10897`).

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ Vladimir Shulyak
368368
Vladimir Zakharov
369369
Vladyslav Bohaichuk
370370
Vladyslav Bondar
371+
Vojtěch Boček
371372
W. Trevor King
372373
Wei Lin
373374
Weiwei Wang

aiohttp/connector.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,9 +926,14 @@ def __init__(
926926
)
927927

928928
self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)
929+
930+
self._resolver: AbstractResolver
929931
if resolver is None:
930-
resolver = DefaultResolver(loop=self._loop)
931-
self._resolver = resolver
932+
self._resolver = DefaultResolver(loop=self._loop)
933+
self._resolver_owner = True
934+
else:
935+
self._resolver = resolver
936+
self._resolver_owner = False
932937

933938
self._use_dns_cache = use_dns_cache
934939
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
@@ -956,6 +961,12 @@ def _close(self) -> List[Awaitable[object]]:
956961

957962
return waiters
958963

964+
async def close(self) -> None:
965+
"""Close all opened transports."""
966+
if self._resolver_owner:
967+
await self._resolver.close()
968+
await super().close()
969+
959970
@property
960971
def family(self) -> int:
961972
"""Socket family like AF_INET."""
@@ -1709,7 +1720,8 @@ def __init__(
17091720
loop=loop,
17101721
)
17111722
if not isinstance(
1712-
self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
1723+
self._loop,
1724+
asyncio.ProactorEventLoop, # type: ignore[attr-defined]
17131725
):
17141726
raise RuntimeError(
17151727
"Named Pipes only available in proactor loop under windows"

aiohttp/resolver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,10 @@ def release_resolver(
258258
loop: The event loop the resolver was using.
259259
"""
260260
# Remove client from its loop's tracking
261-
if loop not in self._loop_data:
261+
current_loop_data = self._loop_data.get(loop)
262+
if current_loop_data is None:
262263
return
263-
resolver, client_set = self._loop_data[loop]
264+
resolver, client_set = current_loop_data
264265
client_set.discard(client)
265266
# If no more clients for this loop, cancel and remove its resolver
266267
if not client_set:

tests/test_connector.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ async def test_tcp_connector_dns_cache_not_expired(loop, dns_response) -> None:
12701270
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
12711271
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
12721272
m_resolver().resolve.return_value = dns_response()
1273+
m_resolver().close = mock.AsyncMock()
12731274
await conn._resolve_host("localhost", 8080)
12741275
await conn._resolve_host("localhost", 8080)
12751276
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
@@ -1281,6 +1282,7 @@ async def test_tcp_connector_dns_cache_forever(loop, dns_response) -> None:
12811282
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
12821283
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
12831284
m_resolver().resolve.return_value = dns_response()
1285+
m_resolver().close = mock.AsyncMock()
12841286
await conn._resolve_host("localhost", 8080)
12851287
await conn._resolve_host("localhost", 8080)
12861288
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
@@ -1292,6 +1294,7 @@ async def test_tcp_connector_use_dns_cache_disabled(loop, dns_response) -> None:
12921294
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
12931295
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
12941296
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
1297+
m_resolver().close = mock.AsyncMock()
12951298
await conn._resolve_host("localhost", 8080)
12961299
await conn._resolve_host("localhost", 8080)
12971300
m_resolver().resolve.assert_has_calls(
@@ -1308,6 +1311,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None:
13081311
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13091312
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
13101313
m_resolver().resolve.return_value = dns_response()
1314+
m_resolver().close = mock.AsyncMock()
13111315
loop.create_task(conn._resolve_host("localhost", 8080))
13121316
loop.create_task(conn._resolve_host("localhost", 8080))
13131317
await asyncio.sleep(0)
@@ -1322,6 +1326,7 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
13221326
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
13231327
e = Exception()
13241328
m_resolver().resolve.side_effect = e
1329+
m_resolver().close = mock.AsyncMock()
13251330
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
13261331
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
13271332
await asyncio.sleep(0)
@@ -1337,10 +1342,10 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
13371342
async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
13381343
loop, dns_response
13391344
):
1340-
13411345
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13421346
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
13431347
m_resolver().resolve.return_value = dns_response()
1348+
m_resolver().close = mock.AsyncMock()
13441349
loop.create_task(conn._resolve_host("localhost", 8080))
13451350
f = loop.create_task(conn._resolve_host("localhost", 8080))
13461351

@@ -1367,7 +1372,6 @@ async def coro():
13671372
async def test_tcp_connector_cancel_dns_error_captured(
13681373
loop, dns_response_error
13691374
) -> None:
1370-
13711375
exception_handler_called = False
13721376

13731377
def exception_handler(loop, context):
@@ -1384,6 +1388,7 @@ def exception_handler(loop, context):
13841388
use_dns_cache=False,
13851389
)
13861390
m_resolver().resolve.return_value = dns_response_error()
1391+
m_resolver().close = mock.AsyncMock()
13871392
f = loop.create_task(conn._create_direct_connection(req, [], ClientTimeout(0)))
13881393

13891394
await asyncio.sleep(0)
@@ -1419,6 +1424,7 @@ async def test_tcp_connector_dns_tracing(loop, dns_response) -> None:
14191424
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
14201425

14211426
m_resolver().resolve.return_value = dns_response()
1427+
m_resolver().close = mock.AsyncMock()
14221428

14231429
await conn._resolve_host("localhost", 8080, traces=traces)
14241430
on_dns_resolvehost_start.assert_called_once_with(
@@ -1460,6 +1466,7 @@ async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response) -> N
14601466
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
14611467

14621468
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
1469+
m_resolver().close = mock.AsyncMock()
14631470

14641471
await conn._resolve_host("localhost", 8080, traces=traces)
14651472

@@ -1514,6 +1521,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
15141521
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
15151522
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=True, ttl_dns_cache=10)
15161523
m_resolver().resolve.return_value = dns_response()
1524+
m_resolver().close = mock.AsyncMock()
15171525
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
15181526
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
15191527
await asyncio.sleep(0)
@@ -1528,6 +1536,14 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
15281536
await conn.close()
15291537

15301538

1539+
async def test_tcp_connector_close_resolver() -> None:
1540+
m_resolver = mock.AsyncMock()
1541+
with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver):
1542+
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
1543+
await conn.close()
1544+
m_resolver.close.assert_awaited_once()
1545+
1546+
15311547
async def test_dns_error(loop) -> None:
15321548
connector = aiohttp.TCPConnector(loop=loop)
15331549
connector._resolve_host = mock.AsyncMock(
@@ -2896,7 +2912,6 @@ async def f():
28962912

28972913

28982914
async def test_connect_with_limit_cancelled(loop) -> None:
2899-
29002915
proto = create_mocked_conn()
29012916
proto.is_connected.return_value = True
29022917

@@ -3691,6 +3706,7 @@ async def resolve_response() -> List[ResolveResult]:
36913706

36923707
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
36933708
m_resolver().resolve.return_value = resolve_response()
3709+
m_resolver().close = mock.AsyncMock()
36943710

36953711
connector = TCPConnector()
36963712
traces = [DummyTracer()]

0 commit comments

Comments
 (0)