Skip to content

Commit 0abffd5

Browse files
[PR #11074/e550c78a backport][3.12] Fix connector not waiting for connections to close (#11077)
Co-authored-by: J. Nick Koston <[email protected]> fixes #1925 fixes #3736
1 parent 2002b9d commit 0abffd5

File tree

7 files changed

+130
-34
lines changed

7 files changed

+130
-34
lines changed

CHANGES/11074.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed connector not waiting for connections to close before returning from :meth:`~aiohttp.BaseConnector.close` (partial backport of :pr:`3733`) -- by :user:`atemate` and :user:`bdraco`.

CHANGES/1925.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11074.bugfix.rst

aiohttp/client_proto.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .base_protocol import BaseProtocol
66
from .client_exceptions import (
7+
ClientConnectionError,
78
ClientOSError,
89
ClientPayloadError,
910
ServerDisconnectedError,
@@ -14,6 +15,7 @@
1415
EMPTY_BODY_STATUS_CODES,
1516
BaseTimerContext,
1617
set_exception,
18+
set_result,
1719
)
1820
from .http import HttpResponseParser, RawResponseMessage
1921
from .http_exceptions import HttpProcessingError
@@ -43,6 +45,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
4345
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
4446

4547
self._timeout_ceil_threshold: Optional[float] = 5
48+
self.closed: asyncio.Future[None] = self._loop.create_future()
4649

4750
@property
4851
def upgraded(self) -> bool:
@@ -83,6 +86,18 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
8386

8487
connection_closed_cleanly = original_connection_error is None
8588

89+
if connection_closed_cleanly:
90+
set_result(self.closed, None)
91+
else:
92+
assert original_connection_error is not None
93+
set_exception(
94+
self.closed,
95+
ClientConnectionError(
96+
f"Connection lost: {original_connection_error !s}",
97+
),
98+
original_connection_error,
99+
)
100+
86101
if self._payload_parser is not None:
87102
with suppress(Exception): # FIXME: log this somehow?
88103
self._payload_parser.feed_eof()

aiohttp/connector.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import functools
3+
import logging
34
import random
45
import socket
56
import sys
@@ -131,6 +132,14 @@ def __del__(self) -> None:
131132
)
132133

133134

135+
async def _wait_for_close(waiters: List[Awaitable[object]]) -> None:
136+
"""Wait for all waiters to finish closing."""
137+
results = await asyncio.gather(*waiters, return_exceptions=True)
138+
for res in results:
139+
if isinstance(res, Exception):
140+
logging.error("Error while closing connector: %r", res)
141+
142+
134143
class Connection:
135144

136145
_source_traceback = None
@@ -222,10 +231,14 @@ def closed(self) -> bool:
222231
class _TransportPlaceholder:
223232
"""placeholder for BaseConnector.connect function"""
224233

225-
__slots__ = ()
234+
__slots__ = ("closed",)
235+
236+
def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None:
237+
"""Initialize a placeholder for a transport."""
238+
self.closed = closed_future
226239

227240
def close(self) -> None:
228-
"""Close the placeholder transport."""
241+
"""Close the placeholder."""
229242

230243

231244
class BaseConnector:
@@ -322,6 +335,10 @@ def __init__(
322335

323336
self._cleanup_closed_disabled = not enable_cleanup_closed
324337
self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
338+
self._placeholder_future: asyncio.Future[Optional[Exception]] = (
339+
loop.create_future()
340+
)
341+
self._placeholder_future.set_result(None)
325342
self._cleanup_closed()
326343

327344
def __del__(self, _warnings: Any = warnings) -> None:
@@ -454,18 +471,30 @@ def _cleanup_closed(self) -> None:
454471

455472
def close(self) -> Awaitable[None]:
456473
"""Close all opened transports."""
457-
self._close()
458-
return _DeprecationWaiter(noop())
474+
if not (waiters := self._close()):
475+
# If there are no connections to close, we can return a noop
476+
# awaitable to avoid scheduling a task on the event loop.
477+
return _DeprecationWaiter(noop())
478+
coro = _wait_for_close(waiters)
479+
if sys.version_info >= (3, 12):
480+
# Optimization for Python 3.12, try to close connections
481+
# immediately to avoid having to schedule the task on the event loop.
482+
task = asyncio.Task(coro, loop=self._loop, eager_start=True)
483+
else:
484+
task = self._loop.create_task(coro)
485+
return _DeprecationWaiter(task)
486+
487+
def _close(self) -> List[Awaitable[object]]:
488+
waiters: List[Awaitable[object]] = []
459489

460-
def _close(self) -> None:
461490
if self._closed:
462-
return
491+
return waiters
463492

464493
self._closed = True
465494

466495
try:
467496
if self._loop.is_closed():
468-
return
497+
return waiters
469498

470499
# cancel cleanup task
471500
if self._cleanup_handle:
@@ -476,16 +505,20 @@ def _close(self) -> None:
476505
self._cleanup_closed_handle.cancel()
477506

478507
for data in self._conns.values():
479-
for proto, t0 in data:
508+
for proto, _ in data:
480509
proto.close()
510+
waiters.append(proto.closed)
481511

482512
for proto in self._acquired:
483513
proto.close()
514+
waiters.append(proto.closed)
484515

485516
for transport in self._cleanup_closed_transports:
486517
if transport is not None:
487518
transport.abort()
488519

520+
return waiters
521+
489522
finally:
490523
self._conns.clear()
491524
self._acquired.clear()
@@ -546,7 +579,9 @@ async def connect(
546579
if (conn := await self._get(key, traces)) is not None:
547580
return conn
548581

549-
placeholder = cast(ResponseHandler, _TransportPlaceholder())
582+
placeholder = cast(
583+
ResponseHandler, _TransportPlaceholder(self._placeholder_future)
584+
)
550585
self._acquired.add(placeholder)
551586
if self._limit_per_host:
552587
self._acquired_per_host[key].add(placeholder)
@@ -898,15 +933,18 @@ def __init__(
898933
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
899934
self._socket_factory = socket_factory
900935

901-
def close(self) -> Awaitable[None]:
936+
def _close(self) -> List[Awaitable[object]]:
902937
"""Close all ongoing DNS calls."""
903938
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
904939
fut.cancel()
905940

941+
waiters = super()._close()
942+
906943
for t in self._resolve_host_tasks:
907944
t.cancel()
945+
waiters.append(t)
908946

909-
return super().close()
947+
return waiters
910948

911949
@property
912950
def family(self) -> int:

tests/test_client_request.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def protocol(loop, transport):
6969
protocol.transport = transport
7070
protocol._drain_helper.return_value = loop.create_future()
7171
protocol._drain_helper.return_value.set_result(None)
72+
protocol.closed = loop.create_future()
73+
protocol.closed.set_result(None)
7274
return protocol
7375

7476

@@ -1404,7 +1406,10 @@ async def send(self, conn):
14041406

14051407
async def create_connection(req, traces, timeout):
14061408
assert isinstance(req, CustomRequest)
1407-
return mock.Mock()
1409+
proto = mock.Mock()
1410+
proto.closed = loop.create_future()
1411+
proto.closed.set_result(None)
1412+
return proto
14081413

14091414
connector = BaseConnector(loop=loop)
14101415
connector._create_connection = create_connection

tests/test_client_session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ async def make_conn():
3333

3434
conn = loop.run_until_complete(make_conn())
3535
proto = mock.Mock()
36+
proto.closed = loop.create_future()
37+
proto.closed.set_result(None)
3638
conn._conns["a"] = deque([(proto, 123)])
3739
yield conn
3840
loop.run_until_complete(conn.close())
@@ -429,7 +431,10 @@ async def test_reraise_os_error(create_session) -> None:
429431

430432
async def create_connection(req, traces, timeout):
431433
# return self.transport, self.protocol
432-
return mock.Mock()
434+
proto = mock.Mock()
435+
proto.closed = session._loop.create_future()
436+
proto.closed.set_result(None)
437+
return proto
433438

434439
session._connector._create_connection = create_connection
435440
session._connector._release = mock.Mock()
@@ -464,6 +469,8 @@ async def connect(req, traces, timeout):
464469
async def create_connection(req, traces, timeout):
465470
# return self.transport, self.protocol
466471
conn = mock.Mock()
472+
conn.closed = session._loop.create_future()
473+
conn.closed.set_result(None)
467474
return conn
468475

469476
session._connector.connect = connect

0 commit comments

Comments
 (0)