diff --git a/test/asynchronous/test_heartbeat_monitoring.py b/test/asynchronous/test_heartbeat_monitoring.py new file mode 100644 index 0000000000..ff595a8144 --- /dev/null +++ b/test/asynchronous/test_heartbeat_monitoring.py @@ -0,0 +1,97 @@ +# Copyright 2016-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the monitoring of the server heartbeats.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.utils import AsyncMockPool, HeartbeatEventListener, async_wait_until + +from pymongo.asynchronous.monitor import Monitor +from pymongo.errors import ConnectionFailure +from pymongo.hello import Hello, HelloCompat + +_IS_SYNC = False + + +class TestHeartbeatMonitoring(AsyncIntegrationTest): + async def create_mock_monitor(self, responses, uri, expected_results): + listener = HeartbeatEventListener() + with client_knobs( + heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1 + ): + + class MockMonitor(Monitor): + async def _check_with_socket(self, *args, **kwargs): + if isinstance(responses[1], Exception): + raise responses[1] + return Hello(responses[1]), 99 + + _ = await self.async_single_client( + h=uri, + event_listeners=(listener,), + _monitor_class=MockMonitor, + _pool_class=AsyncMockPool, + connect=True, + ) + + expected_len = len(expected_results) + # Wait for *at least* expected_len number of results. The + # monitor thread may run multiple times during the execution + # of this test. + await async_wait_until( + lambda: len(listener.events) >= expected_len, "publish all events" + ) + + # zip gives us len(expected_results) pairs. + for expected, actual in zip(expected_results, listener.events): + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": + if isinstance(actual.reply, Hello): + self.assertEqual(actual.duration, 99) + self.assertEqual(actual.reply._doc, responses[1]) + else: + self.assertEqual(actual.reply, responses[1]) + + async def test_standalone(self): + responses = ( + ("a", 27017), + {HelloCompat.LEGACY_CMD: True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1}, + ) + uri = "mongodb://a:27017" + expected_results = ["ServerHeartbeatStartedEvent", "ServerHeartbeatSucceededEvent"] + + await self.create_mock_monitor(responses, uri, expected_results) + + async def test_standalone_error(self): + responses = (("a", 27017), ConnectionFailure("SPECIAL MESSAGE")) + uri = "mongodb://a:27017" + # _check_with_socket failing results in a second attempt. + expected_results = [ + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + ] + + await self.create_mock_monitor(responses, uri, expected_results) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_client.py b/test/test_client.py index 2a33077f5f..cdc7691c28 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2399,7 +2399,7 @@ def test_reconnect(self): # MongoClient discovers it's alone. The first attempt raises either # ServerSelectionTimeoutError or AutoReconnect (from - # AsyncMockPool.get_socket). + # MockPool.get_socket). with self.assertRaises(AutoReconnect): c.db.collection.find_one() diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 5e203a33b3..0523d0ba4d 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -26,6 +26,8 @@ from pymongo.hello import Hello, HelloCompat from pymongo.synchronous.monitor import Monitor +_IS_SYNC = True + class TestHeartbeatMonitoring(IntegrationTest): def create_mock_monitor(self, responses, uri, expected_results): @@ -40,8 +42,12 @@ def _check_with_socket(self, *args, **kwargs): raise responses[1] return Hello(responses[1]), 99 - m = self.single_client( - h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool + _ = self.single_client( + h=uri, + event_listeners=(listener,), + _monitor_class=MockMonitor, + _pool_class=MockPool, + connect=True, ) expected_len = len(expected_results) @@ -50,20 +56,16 @@ def _check_with_socket(self, *args, **kwargs): # of this test. wait_until(lambda: len(listener.events) >= expected_len, "publish all events") - try: - # zip gives us len(expected_results) pairs. - for expected, actual in zip(expected_results, listener.events): - self.assertEqual(expected, actual.__class__.__name__) - self.assertEqual(actual.connection_id, responses[0]) - if expected != "ServerHeartbeatStartedEvent": - if isinstance(actual.reply, Hello): - self.assertEqual(actual.duration, 99) - self.assertEqual(actual.reply._doc, responses[1]) - else: - self.assertEqual(actual.reply, responses[1]) - - finally: - m.close() + # zip gives us len(expected_results) pairs. + for expected, actual in zip(expected_results, listener.events): + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": + if isinstance(actual.reply, Hello): + self.assertEqual(actual.duration, 99) + self.assertEqual(actual.reply._doc, responses[1]) + else: + self.assertEqual(actual.reply, responses[1]) def test_standalone(self): responses = ( diff --git a/test/utils.py b/test/utils.py index 69154bc63b..91000a636a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -43,7 +43,7 @@ from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat from pymongo.helpers_shared import _SENSITIVE_COMMANDS -from pymongo.lock import _create_lock +from pymongo.lock import _async_create_lock, _create_lock from pymongo.monitoring import ( ConnectionCheckedInEvent, ConnectionCheckedOutEvent, @@ -312,6 +312,22 @@ def failed(self, event): self.event_list.append("serverHeartbeatFailedEvent") +class AsyncMockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) + + def close_conn(self, reason): + pass + + def __aenter__(self): + return self + + def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + class MockConnection: def __init__(self): self.cancel_context = _CancellationContext() @@ -328,6 +344,47 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass +class AsyncMockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _async_create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] + + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + + @contextlib.asynccontextmanager + async def checkout(self, handler=None): + yield AsyncMockConnection() + + async def checkin(self, *args, **kwargs): + pass + + async def _reset(self, service_id=None): + async with self._lock: + self.gen.inc(service_id) + + async def ready(self): + pass + + async def reset(self, service_id=None, interrupt_connections=False): + await self._reset() + + async def reset_without_pause(self): + await self._reset() + + async def close(self): + await self._reset() + + async def update_is_writable(self, is_writable): + pass + + async def remove_stale_sockets(self, *args, **kwargs): + pass + + class MockPool: def __init__(self, address, options, handshake=True, client_id=None): self.gen = _PoolGeneration() diff --git a/tools/synchro.py b/tools/synchro.py index 08281c73d0..74b7c80533 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,8 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "AsyncMockConnection": "MockConnection", + "AsyncMockPool": "MockPool", } docstring_replacements: dict[tuple[str, str], str] = { @@ -206,6 +208,7 @@ def async_only_test(f: str) -> bool: "test_database.py", "test_data_lake.py", "test_encryption.py", + "test_heartbeat_monitoring.py", "test_index_management.py", "test_grid_file.py", "test_gridfs_spec.py",