1
1
import asyncio
2
2
import functools
3
+ import logging
3
4
import random
4
5
import socket
5
6
import sys
@@ -118,6 +119,14 @@ def __del__(self) -> None:
118
119
)
119
120
120
121
122
+ async def _wait_for_close (waiters : List [Awaitable [object ]]) -> None :
123
+ """Wait for all waiters to finish closing."""
124
+ results = await asyncio .gather (* waiters , return_exceptions = True )
125
+ for res in results :
126
+ if isinstance (res , Exception ):
127
+ logging .error ("Error while closing connector: %r" , res )
128
+
129
+
121
130
class Connection :
122
131
123
132
_source_traceback = None
@@ -209,10 +218,14 @@ def closed(self) -> bool:
209
218
class _TransportPlaceholder :
210
219
"""placeholder for BaseConnector.connect function"""
211
220
212
- __slots__ = ()
221
+ __slots__ = ("closed" ,)
222
+
223
+ def __init__ (self , closed_future : asyncio .Future [Optional [Exception ]]) -> None :
224
+ """Initialize a placeholder for a transport."""
225
+ self .closed = closed_future
213
226
214
227
def close (self ) -> None :
215
- """Close the placeholder transport ."""
228
+ """Close the placeholder."""
216
229
217
230
218
231
class BaseConnector :
@@ -309,6 +322,10 @@ def __init__(
309
322
310
323
self ._cleanup_closed_disabled = not enable_cleanup_closed
311
324
self ._cleanup_closed_transports : List [Optional [asyncio .Transport ]] = []
325
+ self ._placeholder_future : asyncio .Future [Optional [Exception ]] = (
326
+ loop .create_future ()
327
+ )
328
+ self ._placeholder_future .set_result (None )
312
329
self ._cleanup_closed ()
313
330
314
331
def __del__ (self , _warnings : Any = warnings ) -> None :
@@ -441,18 +458,30 @@ def _cleanup_closed(self) -> None:
441
458
442
459
def close (self ) -> Awaitable [None ]:
443
460
"""Close all opened transports."""
444
- self ._close ()
445
- return _DeprecationWaiter (noop ())
461
+ if not (waiters := self ._close ()):
462
+ # If there are no connections to close, we can return a noop
463
+ # awaitable to avoid scheduling a task on the event loop.
464
+ return _DeprecationWaiter (noop ())
465
+ coro = _wait_for_close (waiters )
466
+ if sys .version_info >= (3 , 12 ):
467
+ # Optimization for Python 3.12, try to close connections
468
+ # immediately to avoid having to schedule the task on the event loop.
469
+ task = asyncio .Task (coro , loop = self ._loop , eager_start = True )
470
+ else :
471
+ task = self ._loop .create_task (coro )
472
+ return _DeprecationWaiter (task )
473
+
474
+ def _close (self ) -> List [Awaitable [object ]]:
475
+ waiters : List [Awaitable [object ]] = []
446
476
447
- def _close (self ) -> None :
448
477
if self ._closed :
449
- return
478
+ return waiters
450
479
451
480
self ._closed = True
452
481
453
482
try :
454
483
if self ._loop .is_closed ():
455
- return
484
+ return waiters
456
485
457
486
# cancel cleanup task
458
487
if self ._cleanup_handle :
@@ -463,16 +492,20 @@ def _close(self) -> None:
463
492
self ._cleanup_closed_handle .cancel ()
464
493
465
494
for data in self ._conns .values ():
466
- for proto , t0 in data :
495
+ for proto , _ in data :
467
496
proto .close ()
497
+ waiters .append (proto .closed )
468
498
469
499
for proto in self ._acquired :
470
500
proto .close ()
501
+ waiters .append (proto .closed )
471
502
472
503
for transport in self ._cleanup_closed_transports :
473
504
if transport is not None :
474
505
transport .abort ()
475
506
507
+ return waiters
508
+
476
509
finally :
477
510
self ._conns .clear ()
478
511
self ._acquired .clear ()
@@ -533,7 +566,9 @@ async def connect(
533
566
if (conn := await self ._get (key , traces )) is not None :
534
567
return conn
535
568
536
- placeholder = cast (ResponseHandler , _TransportPlaceholder ())
569
+ placeholder = cast (
570
+ ResponseHandler , _TransportPlaceholder (self ._placeholder_future )
571
+ )
537
572
self ._acquired .add (placeholder )
538
573
if self ._limit_per_host :
539
574
self ._acquired_per_host [key ].add (placeholder )
@@ -880,15 +915,18 @@ def __init__(
880
915
self ._interleave = interleave
881
916
self ._resolve_host_tasks : Set ["asyncio.Task[List[ResolveResult]]" ] = set ()
882
917
883
- def close (self ) -> Awaitable [None ]:
918
+ def _close (self ) -> List [ Awaitable [object ] ]:
884
919
"""Close all ongoing DNS calls."""
885
920
for fut in chain .from_iterable (self ._throttle_dns_futures .values ()):
886
921
fut .cancel ()
887
922
923
+ waiters = super ()._close ()
924
+
888
925
for t in self ._resolve_host_tasks :
889
926
t .cancel ()
927
+ waiters .append (t )
890
928
891
- return super (). close ()
929
+ return waiters
892
930
893
931
@property
894
932
def family (self ) -> int :
0 commit comments