Skip to content

Commit 6c6a32d

Browse files
committed
Store tasks to be awaited inside Topology
1 parent d14d8e8 commit 6c6a32d

File tree

8 files changed

+52
-28
lines changed

8 files changed

+52
-28
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,14 +1559,16 @@ async def close(self) -> None:
15591559
# Stop the periodic task thread and then send pending killCursor
15601560
# requests before closing the topology.
15611561
self._kill_cursors_executor.close()
1562-
if not _IS_SYNC:
1563-
await self._kill_cursors_executor.join()
15641562
await self._process_kill_cursors()
15651563
await self._topology.close()
15661564
if self._encrypter:
15671565
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
15681566
await self._encrypter.close()
15691567
self._closed = True
1568+
if not _IS_SYNC:
1569+
self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type]
1570+
join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value]
1571+
await asyncio.gather(*join_tasks)
15701572

15711573
if not _IS_SYNC:
15721574
# Add support for contextlib.aclosing.

pymongo/asynchronous/monitor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ async def close(self) -> None:
112112
"""
113113
self.gc_safe_close()
114114

115-
async def join(self, timeout: Optional[int] = None) -> None:
115+
async def join(self) -> None:
116116
"""Wait for the monitor to stop."""
117-
await self._executor.join(timeout)
117+
await self._executor.join()
118118

119119
def request_check(self) -> None:
120120
"""If the monitor is sleeping, wake it soon."""
@@ -189,8 +189,8 @@ def gc_safe_close(self) -> None:
189189
self._rtt_monitor.gc_safe_close()
190190
self.cancel_check()
191191

192-
async def join(self, timeout: Optional[int] = None) -> None:
193-
await self._executor.join(timeout)
192+
async def join(self) -> None:
193+
await self._executor.join()
194194
await self._rtt_monitor.join()
195195

196196
async def close(self) -> None:

pymongo/asynchronous/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ async def close(self) -> None:
115115
)
116116

117117
await self._monitor.close()
118-
if not _IS_SYNC:
119-
await self._monitor.join()
120118
await self._pool.close()
121119

122120
def request_check(self) -> None:

pymongo/asynchronous/topology.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
1920
import logging
2021
import os
2122
import queue
@@ -29,7 +30,7 @@
2930

3031
from pymongo import _csot, common, helpers_shared, periodic_executor
3132
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
32-
from pymongo.asynchronous.monitor import SrvMonitor
33+
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
3334
from pymongo.asynchronous.pool import Pool
3435
from pymongo.asynchronous.server import Server
3536
from pymongo.errors import (
@@ -207,6 +208,9 @@ async def target() -> bool:
207208
if self._settings.fqdn is not None and not self._settings.load_balanced:
208209
self._srv_monitor = SrvMonitor(self, self._settings)
209210

211+
# Stores all monitor tasks that need to be joined on close or server selection
212+
self._monitor_tasks: list[MonitorBase] = []
213+
210214
async def open(self) -> None:
211215
"""Start monitoring, or restart after a fork.
212216
@@ -241,6 +245,7 @@ async def open(self) -> None:
241245
# Close servers and clear the pools.
242246
for server in self._servers.values():
243247
await server.close()
248+
self._monitor_tasks.append(server._monitor)
244249
# Reset the session pool to avoid duplicate sessions in
245250
# the child process.
246251
self._session_pool.reset()
@@ -288,10 +293,17 @@ async def select_servers(
288293
selector, server_timeout, operation, operation_id, address
289294
)
290295

291-
return [
296+
servers = [
292297
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
293298
]
294299

300+
if not _IS_SYNC and self._monitor_tasks:
301+
joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value]
302+
await asyncio.gather(*joins)
303+
self._monitor_tasks = []
304+
305+
return servers
306+
295307
async def _select_servers_loop(
296308
self,
297309
selector: Callable[[Selection], Selection],
@@ -520,8 +532,7 @@ async def _process_change(
520532
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
521533
):
522534
await self._srv_monitor.close()
523-
if not _IS_SYNC:
524-
await self._srv_monitor.join()
535+
self._monitor_tasks.append(self._srv_monitor)
525536

526537
# Clear the pool from a failed heartbeat.
527538
if reset_pool:
@@ -697,6 +708,7 @@ async def close(self) -> None:
697708
old_td = self._description
698709
for server in self._servers.values():
699710
await server.close()
711+
self._monitor_tasks.append(server._monitor)
700712

701713
# Mark all servers Unknown.
702714
self._description = self._description.reset()
@@ -707,8 +719,7 @@ async def close(self) -> None:
707719
# Stop SRV polling thread.
708720
if self._srv_monitor:
709721
await self._srv_monitor.close()
710-
if not _IS_SYNC:
711-
await self._srv_monitor.join()
722+
self._monitor_tasks.append(self._srv_monitor)
712723

713724
self._opened = False
714725
self._closed = True
@@ -948,6 +959,7 @@ async def _update_servers(self) -> None:
948959
for address, server in list(self._servers.items()):
949960
if not self._description.has_server(address):
950961
await server.close()
962+
self._monitor_tasks.append(server._monitor)
951963
self._servers.pop(address)
952964

953965
def _create_pool_for_server(self, address: _Address) -> Pool:

pymongo/synchronous/mongo_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,14 +1553,16 @@ def close(self) -> None:
15531553
# Stop the periodic task thread and then send pending killCursor
15541554
# requests before closing the topology.
15551555
self._kill_cursors_executor.close()
1556-
if not _IS_SYNC:
1557-
self._kill_cursors_executor.join()
15581556
self._process_kill_cursors()
15591557
self._topology.close()
15601558
if self._encrypter:
15611559
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
15621560
self._encrypter.close()
15631561
self._closed = True
1562+
if not _IS_SYNC:
1563+
self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type]
1564+
join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value]
1565+
asyncio.gather(*join_tasks)
15641566

15651567
if not _IS_SYNC:
15661568
# Add support for contextlib.closing.

pymongo/synchronous/monitor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def close(self) -> None:
112112
"""
113113
self.gc_safe_close()
114114

115-
def join(self, timeout: Optional[int] = None) -> None:
115+
def join(self) -> None:
116116
"""Wait for the monitor to stop."""
117-
self._executor.join(timeout)
117+
self._executor.join()
118118

119119
def request_check(self) -> None:
120120
"""If the monitor is sleeping, wake it soon."""
@@ -189,8 +189,8 @@ def gc_safe_close(self) -> None:
189189
self._rtt_monitor.gc_safe_close()
190190
self.cancel_check()
191191

192-
def join(self, timeout: Optional[int] = None) -> None:
193-
self._executor.join(timeout)
192+
def join(self) -> None:
193+
self._executor.join()
194194
self._rtt_monitor.join()
195195

196196
def close(self) -> None:

pymongo/synchronous/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ def close(self) -> None:
115115
)
116116

117117
self._monitor.close()
118-
if not _IS_SYNC:
119-
self._monitor.join()
120118
self._pool.close()
121119

122120
def request_check(self) -> None:

pymongo/synchronous/topology.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
1920
import logging
2021
import os
2122
import queue
@@ -61,7 +62,7 @@
6162
writable_server_selector,
6263
)
6364
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
64-
from pymongo.synchronous.monitor import SrvMonitor
65+
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
6566
from pymongo.synchronous.pool import Pool
6667
from pymongo.synchronous.server import Server
6768
from pymongo.topology_description import (
@@ -207,6 +208,9 @@ def target() -> bool:
207208
if self._settings.fqdn is not None and not self._settings.load_balanced:
208209
self._srv_monitor = SrvMonitor(self, self._settings)
209210

211+
# Stores all monitor tasks that need to be joined on close or server selection
212+
self._monitor_tasks: list[MonitorBase] = []
213+
210214
def open(self) -> None:
211215
"""Start monitoring, or restart after a fork.
212216
@@ -241,6 +245,7 @@ def open(self) -> None:
241245
# Close servers and clear the pools.
242246
for server in self._servers.values():
243247
server.close()
248+
self._monitor_tasks.append(server._monitor)
244249
# Reset the session pool to avoid duplicate sessions in
245250
# the child process.
246251
self._session_pool.reset()
@@ -288,10 +293,17 @@ def select_servers(
288293
selector, server_timeout, operation, operation_id, address
289294
)
290295

291-
return [
296+
servers = [
292297
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
293298
]
294299

300+
if not _IS_SYNC and self._monitor_tasks:
301+
joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value]
302+
asyncio.gather(*joins)
303+
self._monitor_tasks = []
304+
305+
return servers
306+
295307
def _select_servers_loop(
296308
self,
297309
selector: Callable[[Selection], Selection],
@@ -520,8 +532,7 @@ def _process_change(
520532
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
521533
):
522534
self._srv_monitor.close()
523-
if not _IS_SYNC:
524-
self._srv_monitor.join()
535+
self._monitor_tasks.append(self._srv_monitor)
525536

526537
# Clear the pool from a failed heartbeat.
527538
if reset_pool:
@@ -695,6 +706,7 @@ def close(self) -> None:
695706
old_td = self._description
696707
for server in self._servers.values():
697708
server.close()
709+
self._monitor_tasks.append(server._monitor)
698710

699711
# Mark all servers Unknown.
700712
self._description = self._description.reset()
@@ -705,8 +717,7 @@ def close(self) -> None:
705717
# Stop SRV polling thread.
706718
if self._srv_monitor:
707719
self._srv_monitor.close()
708-
if not _IS_SYNC:
709-
self._srv_monitor.join()
720+
self._monitor_tasks.append(self._srv_monitor)
710721

711722
self._opened = False
712723
self._closed = True
@@ -946,6 +957,7 @@ def _update_servers(self) -> None:
946957
for address, server in list(self._servers.items()):
947958
if not self._description.has_server(address):
948959
server.close()
960+
self._monitor_tasks.append(server._monitor)
949961
self._servers.pop(address)
950962

951963
def _create_pool_for_server(self, address: _Address) -> Pool:

0 commit comments

Comments
 (0)