Skip to content

Commit 996ad00

Browse files
Tasssadarbdraco
andauthored
fix: leak of aiodns.DNSResolver when ClientSession is closed (#11150)
Co-authored-by: J. Nick Koston <[email protected]>
1 parent 0544f11 commit 996ad00

File tree

5 files changed

+41
-5
lines changed

5 files changed

+41
-5
lines changed

CHANGES/11150.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed leak of ``aiodns.DNSResolver`` when :py:class:`~aiohttp.TCPConnector` is closed and no resolver was passed when creating the connector -- by :user:`Tasssadar`.
2+
3+
This was a regression introduced in version 3.12.0 (:pr:`10897`).

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ Vladimir Shulyak
378378
Vladimir Zakharov
379379
Vladyslav Bohaichuk
380380
Vladyslav Bondar
381+
Vojtěch Boček
381382
W. Trevor King
382383
Wei Lin
383384
Weiwei Wang

aiohttp/connector.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -883,9 +883,14 @@ def __init__(
883883
"got {!r} instead.".format(ssl)
884884
)
885885
self._ssl = ssl
886+
887+
self._resolver: AbstractResolver
886888
if resolver is None:
887-
resolver = DefaultResolver()
888-
self._resolver: AbstractResolver = resolver
889+
self._resolver = DefaultResolver()
890+
self._resolver_owner = True
891+
else:
892+
self._resolver = resolver
893+
self._resolver_owner = False
889894

890895
self._use_dns_cache = use_dns_cache
891896
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
@@ -912,6 +917,12 @@ def _close_immediately(self) -> List[Awaitable[object]]:
912917

913918
return waiters
914919

920+
async def close(self) -> None:
921+
"""Close all opened transports."""
922+
if self._resolver_owner:
923+
await self._resolver.close()
924+
await super().close()
925+
915926
@property
916927
def family(self) -> int:
917928
"""Socket family like AF_INET."""
@@ -1567,7 +1578,8 @@ def __init__(
15671578
limit_per_host=limit_per_host,
15681579
)
15691580
if not isinstance(
1570-
self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
1581+
self._loop,
1582+
asyncio.ProactorEventLoop, # type: ignore[attr-defined]
15711583
):
15721584
raise RuntimeError(
15731585
"Named Pipes only available in proactor loop under windows"

aiohttp/resolver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,10 @@ def release_resolver(
220220
loop: The event loop the resolver was using.
221221
"""
222222
# Remove client from its loop's tracking
223-
if loop not in self._loop_data:
223+
current_loop_data = self._loop_data.get(loop)
224+
if current_loop_data is None:
224225
return
225-
resolver, client_set = self._loop_data[loop]
226+
resolver, client_set = current_loop_data
226227
client_set.discard(client)
227228
# If no more clients for this loop, cancel and remove its resolver
228229
if not client_set:

tests/test_connector.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,7 @@ async def test_tcp_connector_dns_cache_not_expired(
13011301
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13021302
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
13031303
m_resolver().resolve.return_value = dns_response()
1304+
m_resolver().close = mock.AsyncMock()
13041305
await conn._resolve_host("localhost", 8080)
13051306
await conn._resolve_host("localhost", 8080)
13061307
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
@@ -1314,6 +1315,7 @@ async def test_tcp_connector_dns_cache_forever(
13141315
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13151316
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
13161317
m_resolver().resolve.return_value = dns_response()
1318+
m_resolver().close = mock.AsyncMock()
13171319
await conn._resolve_host("localhost", 8080)
13181320
await conn._resolve_host("localhost", 8080)
13191321
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
@@ -1327,6 +1329,7 @@ async def test_tcp_connector_use_dns_cache_disabled(
13271329
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13281330
conn = aiohttp.TCPConnector(use_dns_cache=False)
13291331
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
1332+
m_resolver().close = mock.AsyncMock()
13301333
await conn._resolve_host("localhost", 8080)
13311334
await conn._resolve_host("localhost", 8080)
13321335
m_resolver().resolve.assert_has_calls(
@@ -1345,6 +1348,7 @@ async def test_tcp_connector_dns_throttle_requests(
13451348
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13461349
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
13471350
m_resolver().resolve.return_value = dns_response()
1351+
m_resolver().close = mock.AsyncMock()
13481352
t = loop.create_task(conn._resolve_host("localhost", 8080))
13491353
t2 = loop.create_task(conn._resolve_host("localhost", 8080))
13501354
await asyncio.sleep(0)
@@ -1365,6 +1369,7 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(
13651369
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
13661370
e = Exception()
13671371
m_resolver().resolve.side_effect = e
1372+
m_resolver().close = mock.AsyncMock()
13681373
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
13691374
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
13701375
await asyncio.sleep(0)
@@ -1383,6 +1388,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
13831388
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
13841389
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
13851390
m_resolver().resolve.return_value = dns_response()
1391+
m_resolver().close = mock.AsyncMock()
13861392
t = loop.create_task(conn._resolve_host("localhost", 8080))
13871393
f = loop.create_task(conn._resolve_host("localhost", 8080))
13881394

@@ -1429,6 +1435,7 @@ def exception_handler(loop: asyncio.AbstractEventLoop, context: object) -> None:
14291435
use_dns_cache=False,
14301436
)
14311437
m_resolver().resolve.return_value = dns_response_error()
1438+
m_resolver().close = mock.AsyncMock()
14321439
f = loop.create_task(conn._create_direct_connection(req, [], ClientTimeout(0)))
14331440

14341441
await asyncio.sleep(0)
@@ -1466,6 +1473,7 @@ async def test_tcp_connector_dns_tracing(
14661473
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
14671474

14681475
m_resolver().resolve.return_value = dns_response()
1476+
m_resolver().close = mock.AsyncMock()
14691477

14701478
await conn._resolve_host("localhost", 8080, traces=traces)
14711479
on_dns_resolvehost_start.assert_called_once_with(
@@ -1509,6 +1517,7 @@ async def test_tcp_connector_dns_tracing_cache_disabled(
15091517
conn = aiohttp.TCPConnector(use_dns_cache=False)
15101518

15111519
m_resolver().resolve.side_effect = [dns_response(), dns_response()]
1520+
m_resolver().close = mock.AsyncMock()
15121521

15131522
await conn._resolve_host("localhost", 8080, traces=traces)
15141523

@@ -1565,6 +1574,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(
15651574
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
15661575
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
15671576
m_resolver().resolve.return_value = dns_response()
1577+
m_resolver().close = mock.AsyncMock()
15681578
t = loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
15691579
t1 = loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
15701580
await asyncio.sleep(0)
@@ -1583,6 +1593,14 @@ async def test_tcp_connector_dns_tracing_throttle_requests(
15831593
await conn.close()
15841594

15851595

1596+
async def test_tcp_connector_close_resolver() -> None:
1597+
m_resolver = mock.AsyncMock()
1598+
with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver):
1599+
conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
1600+
await conn.close()
1601+
m_resolver.close.assert_awaited_once()
1602+
1603+
15861604
async def test_dns_error(loop: asyncio.AbstractEventLoop) -> None:
15871605
connector = aiohttp.TCPConnector()
15881606
with mock.patch.object(
@@ -3834,6 +3852,7 @@ async def resolve_response() -> List[ResolveResult]:
38343852

38353853
with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
38363854
m_resolver().resolve.return_value = resolve_response()
3855+
m_resolver().close = mock.AsyncMock()
38373856

38383857
connector = TCPConnector()
38393858
traces = [DummyTracer()]

0 commit comments

Comments
 (0)