-
Notifications
You must be signed in to change notification settings - Fork 1.1k
PYTHON-5053 - AsyncMongoClient.close() should await all background tasks #2127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
68c4a6e
d14d8e8
6c6a32d
a0a85c5
861dbb5
2ac4ea3
24e96f0
a2eb4bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1565,6 +1565,10 @@ async def close(self) -> None: | |
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. | ||
await self._encrypter.close() | ||
self._closed = True | ||
if not _IS_SYNC: | ||
self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type] | ||
join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value] | ||
|
||
await asyncio.gather(*join_tasks) | ||
|
||
if not _IS_SYNC: | ||
# Add support for contextlib.aclosing. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -112,9 +112,9 @@ async def close(self) -> None: | |
""" | ||
self.gc_safe_close() | ||
|
||
async def join(self, timeout: Optional[int] = None) -> None: | ||
async def join(self) -> None: | ||
"""Wait for the monitor to stop.""" | ||
await self._executor.join(timeout) | ||
await self._executor.join() | ||
|
||
def request_check(self) -> None: | ||
"""If the monitor is sleeping, wake it soon.""" | ||
|
@@ -189,6 +189,10 @@ def gc_safe_close(self) -> None: | |
self._rtt_monitor.gc_safe_close() | ||
self.cancel_check() | ||
|
||
async def join(self) -> None: | ||
await self._executor.join() | ||
await self._rtt_monitor.join() | ||
|
||
|
||
async def close(self) -> None: | ||
self.gc_safe_close() | ||
await self._rtt_monitor.close() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
import logging | ||
import os | ||
import queue | ||
|
@@ -29,7 +30,7 @@ | |
|
||
from pymongo import _csot, common, helpers_shared, periodic_executor | ||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool | ||
from pymongo.asynchronous.monitor import SrvMonitor | ||
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor | ||
from pymongo.asynchronous.pool import Pool | ||
from pymongo.asynchronous.server import Server | ||
from pymongo.errors import ( | ||
|
@@ -207,6 +208,9 @@ async def target() -> bool: | |
if self._settings.fqdn is not None and not self._settings.load_balanced: | ||
self._srv_monitor = SrvMonitor(self, self._settings) | ||
|
||
# Stores all monitor tasks that need to be joined on close or server selection | ||
self._monitor_tasks: list[MonitorBase] = [] | ||
|
||
async def open(self) -> None: | ||
"""Start monitoring, or restart after a fork. | ||
|
||
|
@@ -241,6 +245,7 @@ async def open(self) -> None: | |
# Close servers and clear the pools. | ||
for server in self._servers.values(): | ||
await server.close() | ||
self._monitor_tasks.append(server._monitor) | ||
# Reset the session pool to avoid duplicate sessions in | ||
# the child process. | ||
self._session_pool.reset() | ||
|
@@ -288,10 +293,17 @@ async def select_servers( | |
selector, server_timeout, operation, operation_id, address | ||
) | ||
|
||
return [ | ||
servers = [ | ||
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions | ||
] | ||
|
||
if not _IS_SYNC and self._monitor_tasks: | ||
|
||
joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value] | ||
await asyncio.gather(*joins) | ||
self._monitor_tasks = [] | ||
|
||
return servers | ||
|
||
async def _select_servers_loop( | ||
self, | ||
selector: Callable[[Selection], Selection], | ||
|
@@ -520,6 +532,7 @@ async def _process_change( | |
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES | ||
): | ||
await self._srv_monitor.close() | ||
self._monitor_tasks.append(self._srv_monitor) | ||
|
||
# Clear the pool from a failed heartbeat. | ||
if reset_pool: | ||
|
@@ -695,6 +708,7 @@ async def close(self) -> None: | |
old_td = self._description | ||
for server in self._servers.values(): | ||
await server.close() | ||
self._monitor_tasks.append(server._monitor) | ||
|
||
# Mark all servers Unknown. | ||
self._description = self._description.reset() | ||
|
@@ -705,6 +719,7 @@ async def close(self) -> None: | |
# Stop SRV polling thread. | ||
if self._srv_monitor: | ||
await self._srv_monitor.close() | ||
self._monitor_tasks.append(self._srv_monitor) | ||
|
||
self._opened = False | ||
self._closed = True | ||
|
@@ -944,6 +959,7 @@ async def _update_servers(self) -> None: | |
for address, server in list(self._servers.items()): | ||
if not self._description.has_server(address): | ||
await server.close() | ||
self._monitor_tasks.append(server._monitor) | ||
self._servers.pop(address) | ||
|
||
def _create_pool_for_server(self, address: _Address) -> Pool: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
import logging | ||
import os | ||
import queue | ||
|
@@ -61,7 +62,7 @@ | |
writable_server_selector, | ||
) | ||
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool | ||
from pymongo.synchronous.monitor import SrvMonitor | ||
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor | ||
from pymongo.synchronous.pool import Pool | ||
from pymongo.synchronous.server import Server | ||
from pymongo.topology_description import ( | ||
|
@@ -207,6 +208,9 @@ def target() -> bool: | |
if self._settings.fqdn is not None and not self._settings.load_balanced: | ||
self._srv_monitor = SrvMonitor(self, self._settings) | ||
|
||
# Stores all monitor tasks that need to be joined on close or server selection | ||
self._monitor_tasks: list[MonitorBase] = [] | ||
|
||
def open(self) -> None: | ||
"""Start monitoring, or restart after a fork. | ||
|
||
|
@@ -241,6 +245,7 @@ def open(self) -> None: | |
# Close servers and clear the pools. | ||
for server in self._servers.values(): | ||
server.close() | ||
self._monitor_tasks.append(server._monitor) | ||
|
||
# Reset the session pool to avoid duplicate sessions in | ||
# the child process. | ||
self._session_pool.reset() | ||
|
@@ -288,10 +293,17 @@ def select_servers( | |
selector, server_timeout, operation, operation_id, address | ||
) | ||
|
||
return [ | ||
servers = [ | ||
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions | ||
] | ||
|
||
if not _IS_SYNC and self._monitor_tasks: | ||
joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value] | ||
asyncio.gather(*joins) | ||
self._monitor_tasks = [] | ||
|
||
return servers | ||
|
||
def _select_servers_loop( | ||
self, | ||
selector: Callable[[Selection], Selection], | ||
|
@@ -520,6 +532,7 @@ def _process_change( | |
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES | ||
): | ||
self._srv_monitor.close() | ||
self._monitor_tasks.append(self._srv_monitor) | ||
|
||
# Clear the pool from a failed heartbeat. | ||
if reset_pool: | ||
|
@@ -693,6 +706,7 @@ def close(self) -> None: | |
old_td = self._description | ||
for server in self._servers.values(): | ||
server.close() | ||
self._monitor_tasks.append(server._monitor) | ||
|
||
# Mark all servers Unknown. | ||
self._description = self._description.reset() | ||
|
@@ -703,6 +717,7 @@ def close(self) -> None: | |
# Stop SRV polling thread. | ||
if self._srv_monitor: | ||
self._srv_monitor.close() | ||
self._monitor_tasks.append(self._srv_monitor) | ||
|
||
self._opened = False | ||
self._closed = True | ||
|
@@ -942,6 +957,7 @@ def _update_servers(self) -> None: | |
for address, server in list(self._servers.items()): | ||
if not self._description.has_server(address): | ||
server.close() | ||
self._monitor_tasks.append(server._monitor) | ||
self._servers.pop(address) | ||
|
||
def _create_pool_for_server(self, address: _Address) -> Pool: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's avoid appending to the topology's private state here.