diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f5369b7c4c..43517e1fa7 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -32,6 +32,7 @@ """ from __future__ import annotations +import asyncio import contextlib import os import warnings @@ -2036,6 +2037,8 @@ async def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2050,6 +2053,8 @@ async def _process_kill_cursors(self) -> None: for address, cursor_ids in address_to_cursor_ids.items(): try: await self._kill_cursors(cursor_ids, address, topology, session=None) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2064,6 +2069,8 @@ async def _process_periodic_tasks(self) -> None: try: await self._process_kill_cursors() await self._topology.update_pool() + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index bbfd6a2998..780704fabb 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import atexit import logging import time @@ -26,7 +27,7 @@ from pymongo._csot import MovingMinimum from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _async_create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas @@ -276,7 +277,7 @@ async def _check_server(self) -> ServerDescription: await self._reset_connection() if isinstance(error, _OperationCancelled): raise - self._rtt_monitor.reset() + await self._rtt_monitor.reset() # Server type defaults to Unknown. return ServerDescription(address, error=error) @@ -315,9 +316,9 @@ async def _check_once(self) -> ServerDescription: self._cancel_context = conn.cancel_context response, round_trip_time = await self._check_with_socket(conn) if not response.awaitable: - self._rtt_monitor.add_sample(round_trip_time) + await self._rtt_monitor.add_sample(round_trip_time) - avg_rtt, min_rtt = self._rtt_monitor.get() + avg_rtt, min_rtt = await self._rtt_monitor.get() sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt) if self._publish: assert self._listeners is not None @@ -413,6 +414,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception + except asyncio.CancelledError: + raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -441,7 +444,7 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool self._pool = pool self._moving_average = MovingAverage() self._moving_min = MovingMinimum() - self._lock = _create_lock() + self._lock = _async_create_lock() async def close(self) -> None: self.gc_safe_close() @@ -449,20 +452,20 @@ async def close(self) -> None: # thread has the socket checked out, it will be closed when checked in. await self._pool.reset() - def add_sample(self, sample: float) -> None: + async def add_sample(self, sample: float) -> None: """Add a RTT sample.""" - with self._lock: + async with self._lock: self._moving_average.add_sample(sample) self._moving_min.add_sample(sample) - def get(self) -> tuple[Optional[float], float]: + async def get(self) -> tuple[Optional[float], float]: """Get the calculated average, or None if no samples yet and the min.""" - with self._lock: + async with self._lock: return self._moving_average.get(), self._moving_min.get() - def reset(self) -> None: + async def reset(self) -> None: """Reset the average RTT.""" - with self._lock: + async with self._lock: self._moving_average.reset() self._moving_min.reset() @@ -472,10 +475,12 @@ async def _run(self) -> None: # heartbeat protocol (MongoDB 4.4+). # XXX: Skip check if the server is unknown? rtt = await self._ping() - self.add_sample(rtt) + await self.add_sample(rtt) except ReferenceError: # Topology was garbage-collected. await self.close() + except asyncio.CancelledError: + raise except Exception: await self._pool.reset() diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 2fe9579aef..a37aa3b46a 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -704,6 +704,8 @@ def _close_conn(self) -> None: # shutdown. try: self.conn.close() + except asyncio.CancelledError: + raise except Exception: # noqa: S110 pass diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index aa16e85a07..377689047b 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -271,7 +271,8 @@ async def async_receive_data( ) for task in pending: task.cancel() - await asyncio.wait(pending) + if pending: + await asyncio.wait(pending) if len(done) == 0: raise socket.timeout("timed out") if read_task in done: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index badcffe081..055688ab65 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -32,6 +32,7 @@ """ from __future__ import annotations +import asyncio import contextlib import os import warnings @@ -2030,6 +2031,8 @@ def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2044,6 +2047,8 @@ def _process_kill_cursors(self) -> None: for address, cursor_ids in address_to_cursor_ids.items(): try: self._kill_cursors(cursor_ids, address, topology, session=None) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2058,6 +2063,8 @@ def _process_periodic_tasks(self) -> None: try: self._process_kill_cursors() self._topology.update_pool() + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index a806670f2c..d82a5f976d 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import atexit import logging import time @@ -413,6 +414,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception + except asyncio.CancelledError: + raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -476,6 +479,8 @@ def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. self.close() + except asyncio.CancelledError: + raise except Exception: self._pool.reset() diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 6ac7b4eca9..99201b822e 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -702,6 +702,8 @@ def _close_conn(self) -> None: # shutdown. try: self.conn.close() + except asyncio.CancelledError: + raise except Exception: # noqa: S110 pass diff --git a/test/__init__.py b/test/__init__.py index c1944f5870..dba3312424 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import logging import multiprocessing import os import signal @@ -25,6 +26,7 @@ import sys import threading import time +import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -191,6 +193,8 @@ def _connect(self, host, port, **kwargs): client.close() def _init_client(self): + self.mongoses = [] + self.connection_attempts = [] self.client = self._connect(host, port) if self.client is not None: # Return early when connected to dataLake as mongohoused does not diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 9ca5a32ffc..bed49de161 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import logging import multiprocessing import os import signal @@ -25,6 +26,7 @@ import sys import threading import time +import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -191,6 +193,8 @@ async def _connect(self, host, port, **kwargs): await client.close() async def _init_client(self): + self.mongoses = [] + self.connection_attempts = [] self.client = await self._connect(host, port) if self.client is not None: # Return early when connected to dataLake as mongohoused does not diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index a6ab1cb331..e9e43d5759 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes from pymongo import AsyncMongoClient from pymongo.asynchronous.auth_oidc import OIDCCallback diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 873631bbe5..08da00cc1e 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -35,7 +35,7 @@ async_client_context, unittest, ) -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 47cbff6d5b..292a78d645 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2580,7 +2580,7 @@ async def test_direct_client_maintains_pool_to_arbiter(self): await async_wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) # Assert that we create 1 pooled connection. - listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) + await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) arbiter = c._topology.get_server_by_address(("c", 3)) self.assertEqual(len(arbiter.pool.conns), 1) diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py index 6bc9835b70..945c6c59b5 100644 --- a/test/asynchronous/test_connection_logging.py +++ b/test/asynchronous/test_connection_logging.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes _IS_SYNC = False diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index ffff428379..bc9638b443 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -44,9 +44,6 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): listener: CMAPListener coll: AsyncCollection - async def asyncTearDown(self): - await reset_client_context() - @async_client_context.require_replica_set async def asyncSetUp(self): self.listener = CMAPListener() diff --git a/test/asynchronous/test_create_entities.py b/test/asynchronous/test_create_entities.py index cb2ec63f4c..1f68cf6ddc 100644 --- a/test/asynchronous/test_create_entities.py +++ b/test/asynchronous/test_create_entities.py @@ -56,6 +56,9 @@ async def test_store_events_as_entities(self): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() async def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ async def test_store_all_others_as_entities(self): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() if __name__ == "__main__": diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py index 3d8deb36e9..e6f42d5bdf 100644 --- a/test/asynchronous/test_crud_unified.py +++ b/test/asynchronous/test_crud_unified.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes _IS_SYNC = False diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 9c88f6ff20..a34741c144 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -46,6 +46,7 @@ unittest, ) from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.unified_format import generate_test_classes from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.helpers import ( AWS_CREDS, @@ -56,7 +57,6 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, OvertCommandListener, diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 189e69a921..b2c889c461 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -304,7 +304,6 @@ async def _create_entity(self, entity_spec, uri=None): kwargs["h"] = uri client = await self.test.async_rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addAsyncCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -499,6 +498,10 @@ async def asyncSetUp(self): # process file-level runOnRequirements run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) if not await self.should_run_on(run_on_spec): + # Explicitly close async clients here + # to prevent leaky monitor tasks + if not _IS_SYNC: + await async_client_context.client.close() raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here @@ -530,11 +533,6 @@ async def asyncSetUp(self): # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) - async def asyncTearDown(self): - for client in self.mongos_clients: - await client.close() - await super().asyncTearDown() - def maybe_skip_test(self, spec): # add any special-casing for skipping tests here if async_client_context.storage_engine == "mmapv1": @@ -1042,7 +1040,6 @@ async def _testOperation_targetedFailPoint(self, spec): ) client = await self.async_single_client("{}:{}".format(*session._pinned_address)) - self.addAsyncCleanup(client.close) await self.__set_fail_point(client=client, command_args=spec["failPoint"]) async def _testOperation_createEntities(self, spec): diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f0463244d7..75aa50b578 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -264,8 +264,6 @@ async def asyncSetUp(self) -> None: async def asyncTearDown(self) -> None: self.knobs.disable() - for client in self.mongos_clients: - await client.close() async def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 4387850a00..84ef6decd5 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -44,9 +44,6 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener coll: Collection - def tearDown(self): - reset_client_context() - @client_context.require_replica_set def setUp(self): self.listener = CMAPListener() diff --git a/test/test_create_entities.py b/test/test_create_entities.py index ad75fe5702..9d77a08eee 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -56,6 +56,9 @@ def test_store_events_as_entities(self): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ def test_store_all_others_as_entities(self): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() if __name__ == "__main__": diff --git a/test/unified_format.py b/test/unified_format.py index 766489fb7c..7a38bd4b5e 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -303,7 +303,6 @@ def _create_entity(self, entity_spec, uri=None): kwargs["h"] = uri client = self.test.rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -498,6 +497,10 @@ def setUp(self): # process file-level runOnRequirements run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) if not self.should_run_on(run_on_spec): + # Explicitly close async clients here + # to prevent leaky monitor tasks + if not _IS_SYNC: + client_context.client.close() raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here @@ -529,11 +532,6 @@ def setUp(self): # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) - def tearDown(self): - for client in self.mongos_clients: - client.close() - super().tearDown() - def maybe_skip_test(self, spec): # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": @@ -1033,7 +1031,6 @@ def _testOperation_targetedFailPoint(self, spec): ) client = self.single_client("{}:{}".format(*session._pinned_address)) - self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) def _testOperation_createEntities(self, spec): diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 682cf0b0f8..3dea4ede1c 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -264,8 +264,6 @@ def setUp(self) -> None: def tearDown(self) -> None: self.knobs.disable() - for client in self.mongos_clients: - client.close() def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")])