16
16
17
17
from __future__ import annotations
18
18
19
+ import asyncio
19
20
import logging
20
21
import os
21
22
import queue
29
30
30
31
from pymongo import _csot , common , helpers_shared , periodic_executor
31
32
from pymongo .asynchronous .client_session import _ServerSession , _ServerSessionPool
32
- from pymongo .asynchronous .monitor import SrvMonitor
33
+ from pymongo .asynchronous .monitor import MonitorBase , SrvMonitor
33
34
from pymongo .asynchronous .pool import Pool
34
35
from pymongo .asynchronous .server import Server
35
36
from pymongo .errors import (
@@ -207,6 +208,9 @@ async def target() -> bool:
207
208
if self ._settings .fqdn is not None and not self ._settings .load_balanced :
208
209
self ._srv_monitor = SrvMonitor (self , self ._settings )
209
210
211
+ # Stores all monitor tasks that need to be joined on close or server selection
212
+ self ._monitor_tasks : list [MonitorBase ] = []
213
+
210
214
async def open (self ) -> None :
211
215
"""Start monitoring, or restart after a fork.
212
216
@@ -241,6 +245,8 @@ async def open(self) -> None:
241
245
# Close servers and clear the pools.
242
246
for server in self ._servers .values ():
243
247
await server .close ()
248
+ if not _IS_SYNC :
249
+ self ._monitor_tasks .append (server ._monitor )
244
250
# Reset the session pool to avoid duplicate sessions in
245
251
# the child process.
246
252
self ._session_pool .reset ()
@@ -283,6 +289,10 @@ async def select_servers(
283
289
else :
284
290
server_timeout = server_selection_timeout
285
291
292
+ # Cleanup any completed monitor tasks safely
293
+ if not _IS_SYNC and self ._monitor_tasks :
294
+ await self .cleanup_monitors ()
295
+
286
296
async with self ._lock :
287
297
server_descriptions = await self ._select_servers_loop (
288
298
selector , server_timeout , operation , operation_id , address
@@ -520,6 +530,8 @@ async def _process_change(
520
530
and self ._description .topology_type not in SRV_POLLING_TOPOLOGIES
521
531
):
522
532
await self ._srv_monitor .close ()
533
+ if not _IS_SYNC :
534
+ self ._monitor_tasks .append (self ._srv_monitor )
523
535
524
536
# Clear the pool from a failed heartbeat.
525
537
if reset_pool :
@@ -695,6 +707,8 @@ async def close(self) -> None:
695
707
old_td = self ._description
696
708
for server in self ._servers .values ():
697
709
await server .close ()
710
+ if not _IS_SYNC :
711
+ self ._monitor_tasks .append (server ._monitor )
698
712
699
713
# Mark all servers Unknown.
700
714
self ._description = self ._description .reset ()
@@ -705,6 +719,8 @@ async def close(self) -> None:
705
719
# Stop SRV polling thread.
706
720
if self ._srv_monitor :
707
721
await self ._srv_monitor .close ()
722
+ if not _IS_SYNC :
723
+ self ._monitor_tasks .append (self ._srv_monitor )
708
724
709
725
self ._opened = False
710
726
self ._closed = True
@@ -944,6 +960,8 @@ async def _update_servers(self) -> None:
944
960
for address , server in list (self ._servers .items ()):
945
961
if not self ._description .has_server (address ):
946
962
await server .close ()
963
+ if not _IS_SYNC :
964
+ self ._monitor_tasks .append (server ._monitor )
947
965
self ._servers .pop (address )
948
966
949
967
def _create_pool_for_server (self , address : _Address ) -> Pool :
@@ -1031,6 +1049,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
1031
1049
else :
1032
1050
return "," .join (str (server .error ) for server in servers if server .error )
1033
1051
1052
+ async def cleanup_monitors (self ) -> None :
1053
+ tasks = []
1054
+ try :
1055
+ while self ._monitor_tasks :
1056
+ tasks .append (self ._monitor_tasks .pop ())
1057
+ except IndexError :
1058
+ pass
1059
+ await asyncio .gather (* [t .join () for t in tasks ], return_exceptions = True ) # type: ignore[func-returns-value]
1060
+
1034
1061
def __repr__ (self ) -> str :
1035
1062
msg = ""
1036
1063
if not self ._opened :
0 commit comments