Skip to content

Commit 9f6c54d

Browse files
[PR #11074/e550c78a backport][3.11] Fix connector not waiting for connections to close (#11076)
Co-authored-by: J. Nick Koston <[email protected]> fixes #1925 fixes #3736
1 parent 145658f commit 9f6c54d

File tree

7 files changed

+130
-34
lines changed

7 files changed

+130
-34
lines changed

CHANGES/11074.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed connector not waiting for connections to close before returning from :meth:`~aiohttp.BaseConnector.close` (partial backport of :pr:`3733`) -- by :user:`atemate` and :user:`bdraco`.

CHANGES/1925.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11074.bugfix.rst

aiohttp/client_proto.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .base_protocol import BaseProtocol
66
from .client_exceptions import (
7+
ClientConnectionError,
78
ClientOSError,
89
ClientPayloadError,
910
ServerDisconnectedError,
@@ -14,6 +15,7 @@
1415
EMPTY_BODY_STATUS_CODES,
1516
BaseTimerContext,
1617
set_exception,
18+
set_result,
1719
)
1820
from .http import HttpResponseParser, RawResponseMessage
1921
from .http_exceptions import HttpProcessingError
@@ -43,6 +45,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
4345
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
4446

4547
self._timeout_ceil_threshold: Optional[float] = 5
48+
self.closed: asyncio.Future[None] = self._loop.create_future()
4649

4750
@property
4851
def upgraded(self) -> bool:
@@ -83,6 +86,18 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
8386

8487
connection_closed_cleanly = original_connection_error is None
8588

89+
if connection_closed_cleanly:
90+
set_result(self.closed, None)
91+
else:
92+
assert original_connection_error is not None
93+
set_exception(
94+
self.closed,
95+
ClientConnectionError(
96+
f"Connection lost: {original_connection_error !s}",
97+
),
98+
original_connection_error,
99+
)
100+
86101
if self._payload_parser is not None:
87102
with suppress(Exception): # FIXME: log this somehow?
88103
self._payload_parser.feed_eof()

aiohttp/connector.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import functools
3+
import logging
34
import random
45
import socket
56
import sys
@@ -118,6 +119,14 @@ def __del__(self) -> None:
118119
)
119120

120121

122+
async def _wait_for_close(waiters: List[Awaitable[object]]) -> None:
123+
"""Wait for all waiters to finish closing."""
124+
results = await asyncio.gather(*waiters, return_exceptions=True)
125+
for res in results:
126+
if isinstance(res, Exception):
127+
logging.error("Error while closing connector: %r", res)
128+
129+
121130
class Connection:
122131

123132
_source_traceback = None
@@ -209,10 +218,14 @@ def closed(self) -> bool:
209218
class _TransportPlaceholder:
210219
"""placeholder for BaseConnector.connect function"""
211220

212-
__slots__ = ()
221+
__slots__ = ("closed",)
222+
223+
def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None:
224+
"""Initialize a placeholder for a transport."""
225+
self.closed = closed_future
213226

214227
def close(self) -> None:
215-
"""Close the placeholder transport."""
228+
"""Close the placeholder."""
216229

217230

218231
class BaseConnector:
@@ -309,6 +322,10 @@ def __init__(
309322

310323
self._cleanup_closed_disabled = not enable_cleanup_closed
311324
self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
325+
self._placeholder_future: asyncio.Future[Optional[Exception]] = (
326+
loop.create_future()
327+
)
328+
self._placeholder_future.set_result(None)
312329
self._cleanup_closed()
313330

314331
def __del__(self, _warnings: Any = warnings) -> None:
@@ -441,18 +458,30 @@ def _cleanup_closed(self) -> None:
441458

442459
def close(self) -> Awaitable[None]:
443460
"""Close all opened transports."""
444-
self._close()
445-
return _DeprecationWaiter(noop())
461+
if not (waiters := self._close()):
462+
# If there are no connections to close, we can return a noop
463+
# awaitable to avoid scheduling a task on the event loop.
464+
return _DeprecationWaiter(noop())
465+
coro = _wait_for_close(waiters)
466+
if sys.version_info >= (3, 12):
467+
# Optimization for Python 3.12, try to close connections
468+
# immediately to avoid having to schedule the task on the event loop.
469+
task = asyncio.Task(coro, loop=self._loop, eager_start=True)
470+
else:
471+
task = self._loop.create_task(coro)
472+
return _DeprecationWaiter(task)
473+
474+
def _close(self) -> List[Awaitable[object]]:
475+
waiters: List[Awaitable[object]] = []
446476

447-
def _close(self) -> None:
448477
if self._closed:
449-
return
478+
return waiters
450479

451480
self._closed = True
452481

453482
try:
454483
if self._loop.is_closed():
455-
return
484+
return waiters
456485

457486
# cancel cleanup task
458487
if self._cleanup_handle:
@@ -463,16 +492,20 @@ def _close(self) -> None:
463492
self._cleanup_closed_handle.cancel()
464493

465494
for data in self._conns.values():
466-
for proto, t0 in data:
495+
for proto, _ in data:
467496
proto.close()
497+
waiters.append(proto.closed)
468498

469499
for proto in self._acquired:
470500
proto.close()
501+
waiters.append(proto.closed)
471502

472503
for transport in self._cleanup_closed_transports:
473504
if transport is not None:
474505
transport.abort()
475506

507+
return waiters
508+
476509
finally:
477510
self._conns.clear()
478511
self._acquired.clear()
@@ -533,7 +566,9 @@ async def connect(
533566
if (conn := await self._get(key, traces)) is not None:
534567
return conn
535568

536-
placeholder = cast(ResponseHandler, _TransportPlaceholder())
569+
placeholder = cast(
570+
ResponseHandler, _TransportPlaceholder(self._placeholder_future)
571+
)
537572
self._acquired.add(placeholder)
538573
if self._limit_per_host:
539574
self._acquired_per_host[key].add(placeholder)
@@ -880,15 +915,18 @@ def __init__(
880915
self._interleave = interleave
881916
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
882917

883-
def close(self) -> Awaitable[None]:
918+
def _close(self) -> List[Awaitable[object]]:
884919
"""Close all ongoing DNS calls."""
885920
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
886921
fut.cancel()
887922

923+
waiters = super()._close()
924+
888925
for t in self._resolve_host_tasks:
889926
t.cancel()
927+
waiters.append(t)
890928

891-
return super().close()
929+
return waiters
892930

893931
@property
894932
def family(self) -> int:

tests/test_client_request.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def protocol(loop, transport):
6262
protocol.transport = transport
6363
protocol._drain_helper.return_value = loop.create_future()
6464
protocol._drain_helper.return_value.set_result(None)
65+
protocol.closed = loop.create_future()
66+
protocol.closed.set_result(None)
6567
return protocol
6668

6769

@@ -1325,7 +1327,10 @@ async def send(self, conn):
13251327

13261328
async def create_connection(req, traces, timeout):
13271329
assert isinstance(req, CustomRequest)
1328-
return mock.Mock()
1330+
proto = mock.Mock()
1331+
proto.closed = loop.create_future()
1332+
proto.closed.set_result(None)
1333+
return proto
13291334

13301335
connector = BaseConnector(loop=loop)
13311336
connector._create_connection = create_connection

tests/test_client_session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ async def make_conn():
3333

3434
conn = loop.run_until_complete(make_conn())
3535
proto = mock.Mock()
36+
proto.closed = loop.create_future()
37+
proto.closed.set_result(None)
3638
conn._conns["a"] = deque([(proto, 123)])
3739
yield conn
3840
loop.run_until_complete(conn.close())
@@ -427,7 +429,10 @@ async def test_reraise_os_error(create_session) -> None:
427429

428430
async def create_connection(req, traces, timeout):
429431
# return self.transport, self.protocol
430-
return mock.Mock()
432+
proto = mock.Mock()
433+
proto.closed = session._loop.create_future()
434+
proto.closed.set_result(None)
435+
return proto
431436

432437
session._connector._create_connection = create_connection
433438
session._connector._release = mock.Mock()
@@ -460,6 +465,8 @@ async def connect(req, traces, timeout):
460465
async def create_connection(req, traces, timeout):
461466
# return self.transport, self.protocol
462467
conn = mock.Mock()
468+
conn.closed = session._loop.create_future()
469+
conn.closed.set_result(None)
463470
return conn
464471

465472
session._connector.connect = connect

0 commit comments

Comments
 (0)