PYTHON-5053 - AsyncMongoClient.close() should await all background tasks #2127
```diff
@@ -1565,6 +1565,10 @@ async def close(self) -> None:
             # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
             await self._encrypter.close()
         self._closed = True
+        if not _IS_SYNC:
+            self._topology._monitor_tasks.append(self._kill_cursors_executor)  # type: ignore[arg-type]
+            join_tasks = [t.join() for t in self._topology._monitor_tasks]  # type: ignore[func-returns-value]
```
**Review discussion:**
- Reading
- Are we supporting multithreaded async workloads? My understanding was that we are explicitly not supporting such use cases and assume that all
- Or is this a future-proofing suggestion for when we do the same joining process for synchronous tasks?
- Made this change in the interest of covering our bases and reducing future changes for the sync API.
- Oh sorry, I let my sync brain bleed into the async code. Yeah, async is single-threaded, so it's safe to iterate the list as long as there are no yield points.
- Any objection to making your suggested change anyway, for the reasons I stated above?
```diff
+            await asyncio.gather(*join_tasks)
 
     if not _IS_SYNC:
         # Add support for contextlib.aclosing.
```
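The core change: `close()` now queues the kill-cursors executor alongside the topology's pending monitor tasks and awaits them all with `asyncio.gather()`, so no background task outlives the client. Below is a minimal, self-contained sketch of that shutdown pattern; the `_Executor` class and its method names are hypothetical stand-ins for PyMongo's periodic executors, not the real API.

```python
from __future__ import annotations

import asyncio


class _Executor:
    """Hypothetical stand-in for a periodic background task
    (server monitor, RTT monitor, kill-cursors executor)."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._stopped = asyncio.Event()
        self._task: asyncio.Task[None] | None = None

    def open(self) -> None:
        self._task = asyncio.create_task(self._run())

    async def _run(self) -> None:
        while not self._stopped.is_set():
            await asyncio.sleep(0.01)  # periodic work would happen here

    def close(self) -> None:
        # Signals the loop to exit; does NOT wait for the task to finish.
        self._stopped.set()

    async def join(self) -> None:
        # Waits for the background task to actually terminate.
        if self._task is not None:
            await self._task


async def main() -> None:
    executors = [_Executor(f"task-{i}") for i in range(3)]
    for e in executors:
        e.open()

    # What this PR adds to AsyncMongoClient.close(): signal every
    # background task, then await them all so none outlives the client.
    for e in executors:
        e.close()
    await asyncio.gather(*(e.join() for e in executors))


asyncio.run(main())
```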
```diff
@@ -112,9 +112,9 @@ async def close(self) -> None:
         """
         self.gc_safe_close()
 
-    async def join(self, timeout: Optional[int] = None) -> None:
+    async def join(self) -> None:
         """Wait for the monitor to stop."""
-        await self._executor.join(timeout)
+        await self._executor.join()
 
     def request_check(self) -> None:
         """If the monitor is sleeping, wake it soon."""
```
```diff
@@ -189,6 +189,10 @@ def gc_safe_close(self) -> None:
         self._rtt_monitor.gc_safe_close()
         self.cancel_check()
 
+    async def join(self) -> None:
+        await self._executor.join()
+        await self._rtt_monitor.join()
+
```
**Review comment:** This should use gather too, right?
```diff
     async def close(self) -> None:
         self.gc_safe_close()
         await self._rtt_monitor.close()
```
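The sequential `await`s in the new `join()` wait for the executor and then the RTT monitor, so the total wait is the sum of both shutdown times; `asyncio.gather()`, as the review suggests, waits for them concurrently and bounds the wait by the slower of the two. A quick sketch of the difference (the helper names here are illustrative, not PyMongo API):

```python
import asyncio


async def slow_join(seconds: float) -> None:
    # Stand-in for an executor/monitor join that takes a while.
    await asyncio.sleep(seconds)


async def main() -> None:
    loop = asyncio.get_running_loop()

    # Sequential, as written in the diff: total ~0.2s.
    start = loop.time()
    await slow_join(0.1)
    await slow_join(0.1)
    print(f"sequential: {loop.time() - start:.2f}s")

    # Concurrent, as the review suggests: total ~0.1s.
    start = loop.time()
    await asyncio.gather(slow_join(0.1), slow_join(0.1))
    print(f"gathered:   {loop.time() - start:.2f}s")


asyncio.run(main())
```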
```diff
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
 import os
 import queue
@@ -29,7 +30,7 @@
 
 from pymongo import _csot, common, helpers_shared, periodic_executor
 from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
-from pymongo.asynchronous.monitor import SrvMonitor
+from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
 from pymongo.asynchronous.pool import Pool
 from pymongo.asynchronous.server import Server
 from pymongo.errors import (
@@ -207,6 +208,9 @@ async def target() -> bool:
         if self._settings.fqdn is not None and not self._settings.load_balanced:
             self._srv_monitor = SrvMonitor(self, self._settings)
 
+        # Stores all monitor tasks that need to be joined on close or server selection
+        self._monitor_tasks: list[MonitorBase] = []
+
     async def open(self) -> None:
         """Start monitoring, or restart after a fork.
 
@@ -241,6 +245,7 @@ async def open(self) -> None:
             # Close servers and clear the pools.
             for server in self._servers.values():
                 await server.close()
+                self._monitor_tasks.append(server._monitor)
             # Reset the session pool to avoid duplicate sessions in
             # the child process.
             self._session_pool.reset()
@@ -288,10 +293,17 @@ async def select_servers(
                 selector, server_timeout, operation, operation_id, address
             )
 
-        return [
+        servers = [
             cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
         ]
 
+        if not _IS_SYNC and self._monitor_tasks:
```
**Review discussion:**
- Worth putting a comment here to explain why this code is here. Also, this should happen before selecting the server; doing it after will increase the risk of returning stale information.
- The risk being the delay added by the cleanup between selecting the server and actually returning it? Makes sense.
- Yep, that's it.
```diff
+            joins = [t.join() for t in self._monitor_tasks]  # type: ignore[func-returns-value]
+            await asyncio.gather(*joins)
+            self._monitor_tasks = []
+
+        return servers
 
     async def _select_servers_loop(
         self,
         selector: Callable[[Selection], Selection],
@@ -520,6 +532,7 @@ async def _process_change(
             and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
         ):
             await self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)
 
         # Clear the pool from a failed heartbeat.
         if reset_pool:
@@ -695,6 +708,7 @@ async def close(self) -> None:
         old_td = self._description
         for server in self._servers.values():
             await server.close()
+            self._monitor_tasks.append(server._monitor)
 
         # Mark all servers Unknown.
         self._description = self._description.reset()
@@ -705,6 +719,7 @@ async def close(self) -> None:
         # Stop SRV polling thread.
         if self._srv_monitor:
             await self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)
 
         self._opened = False
         self._closed = True
@@ -944,6 +959,7 @@ async def _update_servers(self) -> None:
         for address, server in list(self._servers.items()):
             if not self._description.has_server(address):
                 await server.close()
+                self._monitor_tasks.append(server._monitor)
                 self._servers.pop(address)
 
     def _create_pool_for_server(self, address: _Address) -> Pool:
```
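The pattern across these hunks: whenever a monitor is closed, it is queued on `_monitor_tasks` rather than awaited inline, and the queue is drained with `asyncio.gather()` at the next safe point (server selection or client close). Here is a self-contained sketch of that queue-then-reap pattern; `_Monitor`, `_retire`, and `_reap_monitor_tasks` are illustrative names, not PyMongo's actual API.

```python
from __future__ import annotations

import asyncio


class _Monitor:
    """Hypothetical monitor: close() is a synchronous signal, join() awaits exit."""

    def __init__(self) -> None:
        self._stopped = asyncio.Event()
        self._task = asyncio.create_task(self._run())

    async def _run(self) -> None:
        await self._stopped.wait()  # heartbeat loop would live here

    def close(self) -> None:
        self._stopped.set()  # request shutdown; returns immediately

    async def join(self) -> None:
        await self._task  # wait for the task to actually finish


class _Topology:
    def __init__(self) -> None:
        # Closed-but-not-yet-joined monitors, reaped at the next safe point.
        self._monitor_tasks: list[_Monitor] = []

    def _retire(self, monitor: _Monitor) -> None:
        monitor.close()
        self._monitor_tasks.append(monitor)

    async def _reap_monitor_tasks(self) -> None:
        # Drain the queue concurrently, mirroring the diff's gather() call.
        if self._monitor_tasks:
            pending, self._monitor_tasks = self._monitor_tasks, []
            await asyncio.gather(*(m.join() for m in pending))


async def main() -> None:
    topology = _Topology()
    for _ in range(3):
        topology._retire(_Monitor())      # e.g. server removed or client closing
    await topology._reap_monitor_tasks()  # e.g. next select_servers() or close()


asyncio.run(main())
```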
```diff
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
 import os
 import queue
@@ -61,7 +62,7 @@
     writable_server_selector,
 )
 from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
-from pymongo.synchronous.monitor import SrvMonitor
+from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
 from pymongo.synchronous.pool import Pool
 from pymongo.synchronous.server import Server
 from pymongo.topology_description import (
@@ -207,6 +208,9 @@ def target() -> bool:
         if self._settings.fqdn is not None and not self._settings.load_balanced:
             self._srv_monitor = SrvMonitor(self, self._settings)
 
+        # Stores all monitor tasks that need to be joined on close or server selection
+        self._monitor_tasks: list[MonitorBase] = []
+
     def open(self) -> None:
         """Start monitoring, or restart after a fork.
 
@@ -241,6 +245,7 @@ def open(self) -> None:
             # Close servers and clear the pools.
             for server in self._servers.values():
                 server.close()
+                self._monitor_tasks.append(server._monitor)
```
**Review discussion:**
- I believe we only want to record these tasks on async. Otherwise we'll have an unbounded list of threads in the sync version.
- Good catch, missed this one.
```diff
             # Reset the session pool to avoid duplicate sessions in
             # the child process.
             self._session_pool.reset()
@@ -288,10 +293,17 @@ def select_servers(
                 selector, server_timeout, operation, operation_id, address
             )
 
-        return [
+        servers = [
             cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
         ]
 
+        if not _IS_SYNC and self._monitor_tasks:
+            joins = [t.join() for t in self._monitor_tasks]  # type: ignore[func-returns-value]
+            asyncio.gather(*joins)
+            self._monitor_tasks = []
+
+        return servers
 
     def _select_servers_loop(
         self,
         selector: Callable[[Selection], Selection],
@@ -520,6 +532,7 @@ def _process_change(
             and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
         ):
             self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)
 
         # Clear the pool from a failed heartbeat.
         if reset_pool:
@@ -693,6 +706,7 @@ def close(self) -> None:
         old_td = self._description
         for server in self._servers.values():
             server.close()
+            self._monitor_tasks.append(server._monitor)
 
         # Mark all servers Unknown.
         self._description = self._description.reset()
@@ -703,6 +717,7 @@ def close(self) -> None:
         # Stop SRV polling thread.
         if self._srv_monitor:
             self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)
 
         self._opened = False
         self._closed = True
@@ -942,6 +957,7 @@ def _update_servers(self) -> None:
         for address, server in list(self._servers.items()):
             if not self._description.has_server(address):
                 server.close()
+                self._monitor_tasks.append(server._monitor)
                 self._servers.pop(address)
 
     def _create_pool_for_server(self, address: _Address) -> Pool:
```
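This synchronous file mirrors the asynchronous one because PyMongo generates its sync API from the async code, which is why an un-awaited `asyncio.gather(*joins)` appears here: the `if not _IS_SYNC:` guard makes that branch dead code in the sync variant, so it never runs. A minimal sketch of the guard pattern the review asks for (the `_retire_monitor` helper is hypothetical):

```python
_IS_SYNC = True  # True in the generated sync module, False in the async one


class Topology:
    def __init__(self) -> None:
        self._monitor_tasks: list = []

    def _retire_monitor(self, monitor) -> None:
        monitor.close()
        # Per the review: only queue monitors for joining in the async
        # variant; otherwise the sync version accumulates an unbounded
        # list of already-finished threads that nothing ever drains.
        if not _IS_SYNC:
            self._monitor_tasks.append(monitor)
```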
**Review comment:** Let's avoid appending to the topology's private state here.
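One way to satisfy this comment (an assumption about the suggestion, not what the PR ultimately merged) is to put the bookkeeping behind methods on the topology so the client never touches `_monitor_tasks` directly:

```python
import asyncio


class Topology:
    """Sketch only: a hypothetical public surface replacing direct
    access to Topology._monitor_tasks from the client."""

    def __init__(self) -> None:
        self._monitor_tasks: list = []

    def defer_task_join(self, task) -> None:
        # Accepts any object with an awaitable join(); the client calls
        # this instead of appending to the private list itself.
        self._monitor_tasks.append(task)

    async def join_deferred_tasks(self) -> None:
        pending, self._monitor_tasks = self._monitor_tasks, []
        await asyncio.gather(*(t.join() for t in pending))
```

With that surface, `AsyncMongoClient.close()` would call `self._topology.defer_task_join(self._kill_cursors_executor)` followed by `await self._topology.join_deferred_tasks()` instead of reaching into the private list.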