Skip to content

Commit 37ef0fb

Browse files
authored
[PR #9671/37d9fe6 backport][3.11] Refactor connection waiters to be cancellation safe (#9675)
1 parent fb72954 commit 37ef0fb

File tree

4 files changed

+401
-197
lines changed

4 files changed

+401
-197
lines changed

CHANGES/9670.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9671.bugfix.rst

CHANGES/9671.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed a deadlock that could occur while attempting to get a new connection slot after a timeout -- by :user:`bdraco`.
2+
3+
The connector was not cancellation-safe.

aiohttp/connector.py

Lines changed: 99 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import traceback
77
import warnings
8-
from collections import defaultdict, deque
8+
from collections import OrderedDict, defaultdict
99
from contextlib import suppress
1010
from http import HTTPStatus
1111
from itertools import chain, cycle, islice
@@ -266,9 +266,11 @@ def __init__(
266266
self._force_close = force_close
267267

268268
# {host_key: FIFO list of waiters}
269-
self._waiters: DefaultDict[ConnectionKey, deque[asyncio.Future[None]]] = (
270-
defaultdict(deque)
271-
)
269+
# The FIFO is implemented with an OrderedDict with None keys because
270+
# python does not have an ordered set.
271+
self._waiters: DefaultDict[
272+
ConnectionKey, OrderedDict[asyncio.Future[None], None]
273+
] = defaultdict(OrderedDict)
272274

273275
self._loop = loop
274276
self._factory = functools.partial(ResponseHandler, loop=loop)
@@ -356,7 +358,7 @@ def _cleanup(self) -> None:
356358
# recreate it ever!
357359
self._cleanup_handle = None
358360

359-
now = self._loop.time()
361+
now = monotonic()
360362
timeout = self._keepalive_timeout
361363

362364
if self._conns:
@@ -387,14 +389,6 @@ def _cleanup(self) -> None:
387389
timeout_ceil_threshold=self._timeout_ceil_threshold,
388390
)
389391

390-
def _drop_acquired_per_host(
391-
self, key: "ConnectionKey", val: ResponseHandler
392-
) -> None:
393-
if conns := self._acquired_per_host.get(key):
394-
conns.remove(val)
395-
if not conns:
396-
del self._acquired_per_host[key]
397-
398392
def _cleanup_closed(self) -> None:
399393
"""Double confirmation for transport close.
400394
@@ -455,6 +449,9 @@ def _close(self) -> None:
455449
finally:
456450
self._conns.clear()
457451
self._acquired.clear()
452+
for keyed_waiters in self._waiters.values():
453+
for keyed_waiter in keyed_waiters:
454+
keyed_waiter.cancel()
458455
self._waiters.clear()
459456
self._cleanup_handle = None
460457
self._cleanup_closed_transports.clear()
@@ -498,113 +495,107 @@ async def connect(
498495
) -> Connection:
499496
"""Get from pool or create new connection."""
500497
key = req.connection_key
501-
available = self._available_connections(key)
502-
wait_for_conn = available <= 0 or key in self._waiters
503-
if not wait_for_conn and (proto := self._get(key)) is not None:
498+
if (conn := await self._get(key, traces)) is not None:
504499
# If we do not have to wait and we can get a connection from the pool
505500
# we can avoid the timeout ceil logic and directly return the connection
506-
return await self._reused_connection(key, proto, traces)
501+
return conn
507502

508503
async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
509-
# Wait if there are no available connections or if there are/were
510-
# waiters (i.e. don't steal connection from a waiter about to wake up)
511-
if wait_for_conn:
504+
if self._available_connections(key) <= 0:
512505
await self._wait_for_available_connection(key, traces)
513-
if (proto := self._get(key)) is not None:
514-
return await self._reused_connection(key, proto, traces)
506+
if (conn := await self._get(key, traces)) is not None:
507+
return conn
515508

516509
placeholder = cast(ResponseHandler, _TransportPlaceholder())
517510
self._acquired.add(placeholder)
518511
self._acquired_per_host[key].add(placeholder)
519512

520-
if traces:
521-
for trace in traces:
522-
await trace.send_connection_create_start()
523-
524513
try:
514+
# Traces are done inside the try block to ensure
515+
# that the placeholder is still cleaned up if an exception
516+
# is raised.
517+
if traces:
518+
for trace in traces:
519+
await trace.send_connection_create_start()
525520
proto = await self._create_connection(req, traces, timeout)
526-
if self._closed:
527-
proto.close()
528-
raise ClientConnectionError("Connector is closed.")
521+
if traces:
522+
for trace in traces:
523+
await trace.send_connection_create_end()
529524
except BaseException:
530-
if not self._closed:
531-
self._acquired.remove(placeholder)
532-
self._drop_acquired_per_host(key, placeholder)
533-
self._release_waiter()
525+
self._release_acquired(key, placeholder)
534526
raise
535527
else:
536-
if not self._closed:
537-
self._acquired.remove(placeholder)
538-
self._drop_acquired_per_host(key, placeholder)
539-
540-
if traces:
541-
for trace in traces:
542-
await trace.send_connection_create_end()
543-
544-
return self._acquired_connection(proto, key)
545-
546-
async def _reused_connection(
547-
self, key: "ConnectionKey", proto: ResponseHandler, traces: List["Trace"]
548-
) -> Connection:
549-
if traces:
550-
# Acquire the connection to prevent race conditions with limits
551-
placeholder = cast(ResponseHandler, _TransportPlaceholder())
552-
self._acquired.add(placeholder)
553-
self._acquired_per_host[key].add(placeholder)
554-
for trace in traces:
555-
await trace.send_connection_reuseconn()
556-
self._acquired.remove(placeholder)
557-
self._drop_acquired_per_host(key, placeholder)
558-
return self._acquired_connection(proto, key)
528+
if self._closed:
529+
proto.close()
530+
raise ClientConnectionError("Connector is closed.")
559531

560-
def _acquired_connection(
561-
self, proto: ResponseHandler, key: "ConnectionKey"
562-
) -> Connection:
563-
"""Mark proto as acquired and wrap it in a Connection object."""
532+
# The connection was successfully created, drop the placeholder
533+
# and add the real connection to the acquired set. There should
534+
# be no awaits after the proto is added to the acquired set
535+
# to ensure that the connection is not left in the acquired set
536+
# on cancellation.
537+
acquired_per_host = self._acquired_per_host[key]
538+
self._acquired.remove(placeholder)
539+
acquired_per_host.remove(placeholder)
564540
self._acquired.add(proto)
565-
self._acquired_per_host[key].add(proto)
541+
acquired_per_host.add(proto)
566542
return Connection(self, key, proto, self._loop)
567543

568544
async def _wait_for_available_connection(
569545
self, key: "ConnectionKey", traces: List["Trace"]
570546
) -> None:
571-
"""Wait until there is an available connection."""
572-
fut: asyncio.Future[None] = self._loop.create_future()
573-
574-
# This connection will now count towards the limit.
575-
self._waiters[key].append(fut)
547+
"""Wait for an available connection slot."""
548+
# We loop here because there is a race between
549+
# the connection limit check and the connection
550+
# being acquired. If the connection is acquired
551+
# between the check and the await statement, we
552+
# need to loop again to check if the connection
553+
# slot is still available.
554+
attempts = 0
555+
while True:
556+
fut: asyncio.Future[None] = self._loop.create_future()
557+
keyed_waiters = self._waiters[key]
558+
keyed_waiters[fut] = None
559+
if attempts:
560+
# If we have waited before, we need to move the waiter
561+
# to the front of the queue as otherwise we might get
562+
# starved and hit the timeout.
563+
keyed_waiters.move_to_end(fut, last=False)
576564

577-
if traces:
578-
for trace in traces:
579-
await trace.send_connection_queued_start()
565+
try:
566+
# Traces happen in the try block to ensure that
567+
# the waiter is still cleaned up if an exception is raised.
568+
if traces:
569+
for trace in traces:
570+
await trace.send_connection_queued_start()
571+
await fut
572+
if traces:
573+
for trace in traces:
574+
await trace.send_connection_queued_end()
575+
finally:
576+
# pop the waiter from the queue if it's still
577+
# there and not already removed by _release_waiter
578+
keyed_waiters.pop(fut, None)
579+
if not self._waiters.get(key, True):
580+
del self._waiters[key]
580581

581-
try:
582-
await fut
583-
except BaseException as e:
584-
if key in self._waiters:
585-
# remove a waiter even if it was cancelled, normally it's
586-
# removed when it's notified
587-
with suppress(ValueError):
588-
# fut may no longer be in list
589-
self._waiters[key].remove(fut)
590-
591-
raise e
592-
finally:
593-
if key in self._waiters and not self._waiters[key]:
594-
del self._waiters[key]
582+
if self._available_connections(key) > 0:
583+
break
584+
attempts += 1
595585

596-
if traces:
597-
for trace in traces:
598-
await trace.send_connection_queued_end()
586+
async def _get(
587+
self, key: "ConnectionKey", traces: List["Trace"]
588+
) -> Optional[Connection]:
589+
"""Get next reusable connection for the key or None.
599590
600-
def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
601-
"""Get next reusable connection for the key or None."""
591+
The connection will be marked as acquired.
592+
"""
602593
try:
603594
conns = self._conns[key]
604595
except KeyError:
605596
return None
606597

607-
t1 = self._loop.time()
598+
t1 = monotonic()
608599
while conns:
609600
proto, t0 = conns.pop()
610601
# We will reuse the connection if it's connected and
@@ -613,7 +604,16 @@ def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
613604
if not conns:
614605
# The very last connection was reclaimed: drop the key
615606
del self._conns[key]
616-
return proto
607+
self._acquired.add(proto)
608+
self._acquired_per_host[key].add(proto)
609+
if traces:
610+
for trace in traces:
611+
try:
612+
await trace.send_connection_reuseconn()
613+
except BaseException:
614+
self._release_acquired(key, proto)
615+
raise
616+
return Connection(self, key, proto, self._loop)
617617

618618
# Connection cannot be reused, close it
619619
transport = proto.transport
@@ -647,25 +647,23 @@ def _release_waiter(self) -> None:
647647

648648
waiters = self._waiters[key]
649649
while waiters:
650-
waiter = waiters.popleft()
650+
waiter, _ = waiters.popitem(last=False)
651651
if not waiter.done():
652652
waiter.set_result(None)
653653
return
654654

655655
def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
656+
"""Release acquired connection."""
656657
if self._closed:
657658
# acquired connection is already released on connector closing
658659
return
659660

660-
try:
661-
self._acquired.remove(proto)
662-
self._drop_acquired_per_host(key, proto)
663-
except KeyError: # pragma: no cover
664-
# this may be result of undetermenistic order of objects
665-
# finalization due garbage collection.
666-
pass
667-
else:
668-
self._release_waiter()
661+
self._acquired.discard(proto)
662+
if conns := self._acquired_per_host.get(key):
663+
conns.discard(proto)
664+
if not conns:
665+
del self._acquired_per_host[key]
666+
self._release_waiter()
669667

670668
def _release(
671669
self,
@@ -694,7 +692,7 @@ def _release(
694692
conns = self._conns.get(key)
695693
if conns is None:
696694
conns = self._conns[key] = []
697-
conns.append((protocol, self._loop.time()))
695+
conns.append((protocol, monotonic()))
698696

699697
if self._cleanup_handle is None:
700698
self._cleanup_handle = helpers.weakref_handle(

0 commit comments

Comments
 (0)