PYTHON-5053 - AsyncMongoClient.close() should await all background tasks #2127


Merged · 8 commits · Feb 5, 2025

Changes from 4 commits
4 changes: 4 additions & 0 deletions pymongo/asynchronous/mongo_client.py
@@ -1565,6 +1565,10 @@ async def close(self) -> None:
             # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
             await self._encrypter.close()
         self._closed = True
+        if not _IS_SYNC:
+            self._topology._monitor_tasks.append(self._kill_cursors_executor)  # type: ignore[arg-type]
Member: Let's avoid appending to the topology's private state here.
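One hypothetical way to address that concern (illustrative only; the method name below is invented, not the API this PR settled on) would be for the topology to own the registration, so the client never mutates its private list directly:

def register_task_for_join(self, task: MonitorBase) -> None:
    # Hypothetical Topology helper: record a stopped background task so
    # that close() or server selection can await it later.
    self._monitor_tasks.append(task)

The client's close() would then call self._topology.register_task_for_join(...) instead of appending to _monitor_tasks itself.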

+            join_tasks = [t.join() for t in self._topology._monitor_tasks]  # type: ignore[func-returns-value]
Member: Reading _monitor_tasks like this is not thread safe. Anytime we iterate it we need to guard against the list being mutated from another thread, something like:

tasks = []
try:
    # Drain the shared list item by item; list.pop() is atomic, so a
    # concurrent append from another thread is never lost or read twice.
    while self._topology._monitor_tasks:
        tasks.append(self._topology._monitor_tasks.pop())
except IndexError:
    # Another thread emptied the list between the truthiness check and pop().
    pass

@NoahStapp (Contributor, Author) · Feb 4, 2025: Are we supporting multithreaded async workloads? My understanding was that we are explicitly not supporting such use cases and assume that all AsyncMongoClient operations will take place on a single thread.

@NoahStapp (Contributor, Author): Or is this a future-proofing suggestion for when we do the same joining process for synchronous tasks?

@NoahStapp (Contributor, Author): Made this change in the interest of covering our bases and reducing future changes for the sync API.

@ShaneHarvey (Member) · Feb 4, 2025: Oh sorry, I let my sync brain bleed into the async code. Yeah, async is single-threaded, so it's safe to iterate the list as long as there are no yield points.

@NoahStapp (Contributor, Author): Any objection to doing your suggested change anyway for the reasons I stated above?
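To make the yield-point caveat above concrete, here is a minimal, self-contained sketch (hypothetical names; not PyMongo code). In single-threaded asyncio, code between awaits runs without interleaving, but every await is a point where another task may run and mutate shared state:

import asyncio

items: list[int] = [1, 2, 3]

async def consumer() -> None:
    # Safe: no await inside the loop, so no other task can run mid-iteration.
    for i in items:
        print("safe:", i)
    # An await in the body is a yield point: mutator() may run and grow the
    # list under us, so we iterate over a snapshot for a stable view.
    for i in list(items):
        await asyncio.sleep(0)  # other tasks get to run here
        print("snapshot:", i)

async def mutator() -> None:
    items.append(4)  # interleaves at one of consumer()'s yield points

async def main() -> None:
    await asyncio.gather(consumer(), mutator())

asyncio.run(main())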

+            await asyncio.gather(*join_tasks)

         if not _IS_SYNC:
             # Add support for contextlib.aclosing.
8 changes: 6 additions & 2 deletions pymongo/asynchronous/monitor.py
@@ -112,9 +112,9 @@ async def close(self) -> None:
"""
self.gc_safe_close()

async def join(self, timeout: Optional[int] = None) -> None:
async def join(self) -> None:
"""Wait for the monitor to stop."""
await self._executor.join(timeout)
await self._executor.join()

def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@@ -189,6 +189,10 @@ def gc_safe_close(self) -> None:
         self._rtt_monitor.gc_safe_close()
         self.cancel_check()

+    async def join(self) -> None:
+        await self._executor.join()
+        await self._rtt_monitor.join()
Member: This should use gather too right?
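One possible shape for that suggestion, as a sketch (this may not be exactly what the PR merged): await both joins concurrently rather than sequentially.

async def join(self) -> None:
    # Assumes the executor and RTT monitor shut down independently, so
    # their joins are safe to await concurrently.
    await asyncio.gather(self._executor.join(), self._rtt_monitor.join())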


     async def close(self) -> None:
         self.gc_safe_close()
         await self._rtt_monitor.close()
20 changes: 18 additions & 2 deletions pymongo/asynchronous/topology.py
@@ -16,6 +16,7 @@

 from __future__ import annotations

+import asyncio
 import logging
 import os
 import queue
@@ -29,7 +30,7 @@

 from pymongo import _csot, common, helpers_shared, periodic_executor
 from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
-from pymongo.asynchronous.monitor import SrvMonitor
+from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
 from pymongo.asynchronous.pool import Pool
 from pymongo.asynchronous.server import Server
 from pymongo.errors import (
@@ -207,6 +208,9 @@ async def target() -> bool:
         if self._settings.fqdn is not None and not self._settings.load_balanced:
             self._srv_monitor = SrvMonitor(self, self._settings)

+        # Stores all monitor tasks that need to be joined on close or server selection
+        self._monitor_tasks: list[MonitorBase] = []

     async def open(self) -> None:
         """Start monitoring, or restart after a fork.

@@ -241,6 +245,7 @@ async def open(self) -> None:
             # Close servers and clear the pools.
             for server in self._servers.values():
                 await server.close()
+                self._monitor_tasks.append(server._monitor)
             # Reset the session pool to avoid duplicate sessions in
             # the child process.
             self._session_pool.reset()
@@ -288,10 +293,17 @@ async def select_servers(
             selector, server_timeout, operation, operation_id, address
         )

-        return [
+        servers = [
             cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
         ]

+        if not _IS_SYNC and self._monitor_tasks:
Member: Worth putting a comment here to explain why this code is here. Also this should happen before selecting the server. Doing it after will increase the risk of returning stale information.

@NoahStapp (Contributor, Author) · Feb 4, 2025: The risk being the delay added by the cleanup between selecting the server and actually returning it? Makes sense.

Member: Yep that's it.
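A hedged sketch of the reordering discussed above (illustrative; the final commits may differ): drain and await the stopped monitors before building the server list, so the cleanup delay lands before selection instead of between selecting the servers and returning them.

# Join stopped monitor tasks first so selection reflects current state.
if not _IS_SYNC and self._monitor_tasks:
    joins = [t.join() for t in self._monitor_tasks]  # type: ignore[func-returns-value]
    await asyncio.gather(*joins)
    self._monitor_tasks = []

servers = [
    cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
]
return servers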

+            joins = [t.join() for t in self._monitor_tasks]  # type: ignore[func-returns-value]
+            await asyncio.gather(*joins)
+            self._monitor_tasks = []
+
+        return servers

     async def _select_servers_loop(
         self,
         selector: Callable[[Selection], Selection],

@@ -520,6 +532,7 @@ async def _process_change(
             and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
         ):
             await self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)

         # Clear the pool from a failed heartbeat.
         if reset_pool:

@@ -695,6 +708,7 @@ async def close(self) -> None:
         old_td = self._description
         for server in self._servers.values():
             await server.close()
+            self._monitor_tasks.append(server._monitor)

         # Mark all servers Unknown.
         self._description = self._description.reset()

@@ -705,6 +719,7 @@ async def close(self) -> None:
         # Stop SRV polling thread.
         if self._srv_monitor:
             await self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)

         self._opened = False
         self._closed = True

@@ -944,6 +959,7 @@ async def _update_servers(self) -> None:
         for address, server in list(self._servers.items()):
             if not self._description.has_server(address):
                 await server.close()
+                self._monitor_tasks.append(server._monitor)
                 self._servers.pop(address)

     def _create_pool_for_server(self, address: _Address) -> Pool:
2 changes: 2 additions & 0 deletions pymongo/periodic_executor.py
@@ -75,6 +75,8 @@ def close(self, dummy: Any = None) -> None:
         callback; see monitor.py.
         """
         self._stopped = True
+        if self._task is not None:
+            self._task.cancel()

     async def join(self, timeout: Optional[int] = None) -> None:
         if self._task is not None:
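Taken together, close() cancels the executor's task and join() awaits its completion, the usual asyncio shutdown pairing. A standalone sketch of the pattern (hypothetical worker; not PyMongo code):

import asyncio

async def worker() -> None:
    try:
        while True:
            await asyncio.sleep(0.5)  # periodic background work
    except asyncio.CancelledError:
        pass  # cancellation requested; exit cleanly

async def main() -> None:
    task = asyncio.create_task(worker())
    await asyncio.sleep(0)  # let the worker start running
    task.cancel()           # close(): request cancellation
    await asyncio.gather(task, return_exceptions=True)  # join(): wait for exit

asyncio.run(main())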
4 changes: 4 additions & 0 deletions pymongo/synchronous/mongo_client.py
@@ -1559,6 +1559,10 @@ def close(self) -> None:
             # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
             self._encrypter.close()
         self._closed = True
+        if not _IS_SYNC:
+            self._topology._monitor_tasks.append(self._kill_cursors_executor)  # type: ignore[arg-type]
+            join_tasks = [t.join() for t in self._topology._monitor_tasks]  # type: ignore[func-returns-value]
+            asyncio.gather(*join_tasks)
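(Note: PyMongo's synchronous modules are generated from the asynchronous ones, which is why an asyncio.gather call appears here at all; since _IS_SYNC is True in this module, the guarded block never executes in the sync build.)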

         if not _IS_SYNC:
             # Add support for contextlib.closing.
8 changes: 6 additions & 2 deletions pymongo/synchronous/monitor.py
@@ -112,9 +112,9 @@ def close(self) -> None:
"""
self.gc_safe_close()

def join(self, timeout: Optional[int] = None) -> None:
def join(self) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
self._executor.join()

def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@@ -189,6 +189,10 @@ def gc_safe_close(self) -> None:
         self._rtt_monitor.gc_safe_close()
         self.cancel_check()

+    def join(self) -> None:
+        self._executor.join()
+        self._rtt_monitor.join()

     def close(self) -> None:
         self.gc_safe_close()
         self._rtt_monitor.close()
20 changes: 18 additions & 2 deletions pymongo/synchronous/topology.py
@@ -16,6 +16,7 @@

 from __future__ import annotations

+import asyncio
 import logging
 import os
 import queue
@@ -61,7 +62,7 @@
     writable_server_selector,
 )
 from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
-from pymongo.synchronous.monitor import SrvMonitor
+from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
 from pymongo.synchronous.pool import Pool
 from pymongo.synchronous.server import Server
 from pymongo.topology_description import (
@@ -207,6 +208,9 @@ def target() -> bool:
         if self._settings.fqdn is not None and not self._settings.load_balanced:
             self._srv_monitor = SrvMonitor(self, self._settings)

+        # Stores all monitor tasks that need to be joined on close or server selection
+        self._monitor_tasks: list[MonitorBase] = []

     def open(self) -> None:
         """Start monitoring, or restart after a fork.

@@ -241,6 +245,7 @@ def open(self) -> None:
             # Close servers and clear the pools.
             for server in self._servers.values():
                 server.close()
+                self._monitor_tasks.append(server._monitor)
Member: I believe we only want to record these tasks on async. Otherwise we'll have an unbounded list of threads in the sync version.

@NoahStapp (Contributor, Author): Good catch, missed this one.
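A sketch of the guard this exchange points toward (hedged; the actual fix presumably landed in a later commit of this PR): record monitors for joining only in the async build, so the sync build does not accumulate an unbounded list of finished threads.

server.close()
if not _IS_SYNC:
    # Async only: remember the stopped monitor so it can be awaited later.
    self._monitor_tasks.append(server._monitor)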

             # Reset the session pool to avoid duplicate sessions in
             # the child process.
             self._session_pool.reset()

@@ -288,10 +293,17 @@ def select_servers(
             selector, server_timeout, operation, operation_id, address
         )

-        return [
+        servers = [
             cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
         ]

+        if not _IS_SYNC and self._monitor_tasks:
+            joins = [t.join() for t in self._monitor_tasks]  # type: ignore[func-returns-value]
+            asyncio.gather(*joins)
+            self._monitor_tasks = []
+
+        return servers

     def _select_servers_loop(
         self,
         selector: Callable[[Selection], Selection],

@@ -520,6 +532,7 @@ def _process_change(
             and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
         ):
             self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)

         # Clear the pool from a failed heartbeat.
         if reset_pool:

@@ -693,6 +706,7 @@ def close(self) -> None:
         old_td = self._description
         for server in self._servers.values():
             server.close()
+            self._monitor_tasks.append(server._monitor)

         # Mark all servers Unknown.
         self._description = self._description.reset()

@@ -703,6 +717,7 @@ def close(self) -> None:
         # Stop SRV polling thread.
         if self._srv_monitor:
             self._srv_monitor.close()
+            self._monitor_tasks.append(self._srv_monitor)

         self._opened = False
         self._closed = True

@@ -942,6 +957,7 @@ def _update_servers(self) -> None:
         for address, server in list(self._servers.items()):
             if not self._description.has_server(address):
                 server.close()
+                self._monitor_tasks.append(server._monitor)
                 self._servers.pop(address)

     def _create_pool_for_server(self, address: _Address) -> Pool: