1616
1717from __future__ import annotations
1818
19+ import asyncio
1920import atexit
2021import logging
2122import time
2627from pymongo ._csot import MovingMinimum
2728from pymongo .errors import NetworkTimeout , NotPrimaryError , OperationFailure , _OperationCancelled
2829from pymongo .hello import Hello
29- from pymongo .lock import _create_lock
30+ from pymongo .lock import _async_create_lock
3031from pymongo .logger import _SDAM_LOGGER , _debug_log , _SDAMStatusMessage
3132from pymongo .periodic_executor import _shutdown_executors
3233from pymongo .pool_options import _is_faas
@@ -276,7 +277,7 @@ async def _check_server(self) -> ServerDescription:
276277 await self ._reset_connection ()
277278 if isinstance (error , _OperationCancelled ):
278279 raise
279- self ._rtt_monitor .reset ()
280+ await self ._rtt_monitor .reset ()
280281 # Server type defaults to Unknown.
281282 return ServerDescription (address , error = error )
282283
@@ -315,9 +316,9 @@ async def _check_once(self) -> ServerDescription:
315316 self ._cancel_context = conn .cancel_context
316317 response , round_trip_time = await self ._check_with_socket (conn )
317318 if not response .awaitable :
318- self ._rtt_monitor .add_sample (round_trip_time )
319+ await self ._rtt_monitor .add_sample (round_trip_time )
319320
320- avg_rtt , min_rtt = self ._rtt_monitor .get ()
321+ avg_rtt , min_rtt = await self ._rtt_monitor .get ()
321322 sd = ServerDescription (address , response , avg_rtt , min_round_trip_time = min_rtt )
322323 if self ._publish :
323324 assert self ._listeners is not None
@@ -413,6 +414,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
413414 if len (seedlist ) == 0 :
414415 # As per the spec: this should be treated as a failure.
415416 raise Exception
417+ except asyncio .CancelledError :
418+ raise
416419 except Exception :
417420 # As per the spec, upon encountering an error:
418421 # - An error must not be raised
@@ -441,28 +444,28 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool
441444 self ._pool = pool
442445 self ._moving_average = MovingAverage ()
443446 self ._moving_min = MovingMinimum ()
444- self ._lock = _create_lock ()
447+ self ._lock = _async_create_lock ()
445448
446449 async def close (self ) -> None :
447450 self .gc_safe_close ()
448451 # Increment the generation and maybe close the socket. If the executor
449452 # thread has the socket checked out, it will be closed when checked in.
450453 await self ._pool .reset ()
451454
452- def add_sample (self , sample : float ) -> None :
455+ async def add_sample (self , sample : float ) -> None :
453456 """Add a RTT sample."""
454- with self ._lock :
457+ async with self ._lock :
455458 self ._moving_average .add_sample (sample )
456459 self ._moving_min .add_sample (sample )
457460
458- def get (self ) -> tuple [Optional [float ], float ]:
461+ async def get (self ) -> tuple [Optional [float ], float ]:
459462 """Get the calculated average, or None if no samples yet and the min."""
460- with self ._lock :
463+ async with self ._lock :
461464 return self ._moving_average .get (), self ._moving_min .get ()
462465
463- def reset (self ) -> None :
466+ async def reset (self ) -> None :
464467 """Reset the average RTT."""
465- with self ._lock :
468+ async with self ._lock :
466469 self ._moving_average .reset ()
467470 self ._moving_min .reset ()
468471
@@ -472,10 +475,12 @@ async def _run(self) -> None:
472475 # heartbeat protocol (MongoDB 4.4+).
473476 # XXX: Skip check if the server is unknown?
474477 rtt = await self ._ping ()
475- self .add_sample (rtt )
478+ await self .add_sample (rtt )
476479 except ReferenceError :
477480 # Topology was garbage-collected.
478481 await self .close ()
482+ except asyncio .CancelledError :
483+ raise
479484 except Exception :
480485 await self ._pool .reset ()
481486
0 commit comments