Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
from __future__ import annotations

import asyncio
import contextlib
import os
import warnings
Expand Down Expand Up @@ -2036,6 +2037,8 @@ async def _process_kill_cursors(self) -> None:
for address, cursor_id, conn_mgr in pinned_cursors:
try:
await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
Expand All @@ -2050,6 +2053,8 @@ async def _process_kill_cursors(self) -> None:
for address, cursor_ids in address_to_cursor_ids.items():
try:
await self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
Expand All @@ -2064,6 +2069,8 @@ async def _process_periodic_tasks(self) -> None:
try:
await self._process_kill_cursors()
await self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return
Expand Down
29 changes: 17 additions & 12 deletions pymongo/asynchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import atexit
import logging
import time
Expand All @@ -26,7 +27,7 @@
from pymongo._csot import MovingMinimum
from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled
from pymongo.hello import Hello
from pymongo.lock import _create_lock
from pymongo.lock import _async_create_lock
from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage
from pymongo.periodic_executor import _shutdown_executors
from pymongo.pool_options import _is_faas
Expand Down Expand Up @@ -276,7 +277,7 @@ async def _check_server(self) -> ServerDescription:
await self._reset_connection()
if isinstance(error, _OperationCancelled):
raise
self._rtt_monitor.reset()
await self._rtt_monitor.reset()
# Server type defaults to Unknown.
return ServerDescription(address, error=error)

Expand Down Expand Up @@ -315,9 +316,9 @@ async def _check_once(self) -> ServerDescription:
self._cancel_context = conn.cancel_context
response, round_trip_time = await self._check_with_socket(conn)
if not response.awaitable:
self._rtt_monitor.add_sample(round_trip_time)
await self._rtt_monitor.add_sample(round_trip_time)

avg_rtt, min_rtt = self._rtt_monitor.get()
avg_rtt, min_rtt = await self._rtt_monitor.get()
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
if self._publish:
assert self._listeners is not None
Expand Down Expand Up @@ -413,6 +414,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
Expand Down Expand Up @@ -441,28 +444,28 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool
self._pool = pool
self._moving_average = MovingAverage()
self._moving_min = MovingMinimum()
self._lock = _create_lock()
self._lock = _async_create_lock()

async def close(self) -> None:
self.gc_safe_close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._pool.reset()

def add_sample(self, sample: float) -> None:
async def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
with self._lock:
async with self._lock:
self._moving_average.add_sample(sample)
self._moving_min.add_sample(sample)

def get(self) -> tuple[Optional[float], float]:
async def get(self) -> tuple[Optional[float], float]:
"""Get the calculated average, or None if no samples yet and the min."""
with self._lock:
async with self._lock:
return self._moving_average.get(), self._moving_min.get()

def reset(self) -> None:
async def reset(self) -> None:
"""Reset the average RTT."""
with self._lock:
async with self._lock:
self._moving_average.reset()
self._moving_min.reset()

Expand All @@ -472,10 +475,12 @@ async def _run(self) -> None:
# heartbeat protocol (MongoDB 4.4+).
# XXX: Skip check if the server is unknown?
rtt = await self._ping()
self.add_sample(rtt)
await self.add_sample(rtt)
except ReferenceError:
# Topology was garbage-collected.
await self.close()
except asyncio.CancelledError:
raise
except Exception:
await self._pool.reset()

Expand Down
2 changes: 2 additions & 0 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,8 @@ def _close_conn(self) -> None:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass

Expand Down
3 changes: 2 additions & 1 deletion pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ async def async_receive_data(
)
for task in pending:
task.cancel()
await asyncio.wait(pending)
if pending:
await asyncio.wait(pending)
if len(done) == 0:
raise socket.timeout("timed out")
if read_task in done:
Expand Down
7 changes: 7 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
from __future__ import annotations

import asyncio
import contextlib
import os
import warnings
Expand Down Expand Up @@ -2030,6 +2031,8 @@ def _process_kill_cursors(self) -> None:
for address, cursor_id, conn_mgr in pinned_cursors:
try:
self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
# Raise the exception when client is closed so that it
Expand All @@ -2044,6 +2047,8 @@ def _process_kill_cursors(self) -> None:
for address, cursor_ids in address_to_cursor_ids.items():
try:
self._kill_cursors(cursor_ids, address, topology, session=None)
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
raise
Expand All @@ -2058,6 +2063,8 @@ def _process_periodic_tasks(self) -> None:
try:
self._process_kill_cursors()
self._topology.update_pool()
except asyncio.CancelledError:
raise
except Exception as exc:
if isinstance(exc, InvalidOperation) and self._topology._closed:
return
Expand Down
5 changes: 5 additions & 0 deletions pymongo/synchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import atexit
import logging
import time
Expand Down Expand Up @@ -413,6 +414,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
if len(seedlist) == 0:
# As per the spec: this should be treated as a failure.
raise Exception
except asyncio.CancelledError:
raise
except Exception:
# As per the spec, upon encountering an error:
# - An error must not be raised
Expand Down Expand Up @@ -476,6 +479,8 @@ def _run(self) -> None:
except ReferenceError:
# Topology was garbage-collected.
self.close()
except asyncio.CancelledError:
raise
except Exception:
self._pool.reset()

Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,8 @@ def _close_conn(self) -> None:
# shutdown.
try:
self.conn.close()
except asyncio.CancelledError:
raise
except Exception: # noqa: S110
pass

Expand Down
4 changes: 4 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import gc
import logging
import multiprocessing
import os
import signal
Expand All @@ -25,6 +26,7 @@
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
Expand Down Expand Up @@ -191,6 +193,8 @@ def _connect(self, host, port, **kwargs):
client.close()

def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = self._connect(host, port)
if self.client is not None:
# Return early when connected to dataLake as mongohoused does not
Expand Down
4 changes: 4 additions & 0 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import gc
import logging
import multiprocessing
import os
import signal
Expand All @@ -25,6 +26,7 @@
import sys
import threading
import time
import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
Expand Down Expand Up @@ -191,6 +193,8 @@ async def _connect(self, host, port, **kwargs):
await client.close()

async def _init_client(self):
self.mongoses = []
self.connection_attempts = []
self.client = await self._connect(host, port)
if self.client is not None:
# Return early when connected to dataLake as mongohoused does not
Expand Down
2 changes: 1 addition & 1 deletion test/asynchronous/test_auth_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
sys.path[0:0] = [""]

from test import unittest
from test.unified_format import generate_test_classes
from test.asynchronous.unified_format import generate_test_classes

from pymongo import AsyncMongoClient
from pymongo.asynchronous.auth_oidc import OIDCCallback
Expand Down
2 changes: 1 addition & 1 deletion test/asynchronous/test_change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
async_client_context,
unittest,
)
from test.unified_format import generate_test_classes
from test.asynchronous.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
EventListener,
Expand Down
2 changes: 1 addition & 1 deletion test/asynchronous/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2580,7 +2580,7 @@ async def test_direct_client_maintains_pool_to_arbiter(self):
await async_wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3))
# Assert that we create 1 pooled connection.
listener.wait_for_event(monitoring.ConnectionReadyEvent, 1)
await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1)
self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1)
arbiter = c._topology.get_server_by_address(("c", 3))
self.assertEqual(len(arbiter.pool.conns), 1)
Expand Down
2 changes: 1 addition & 1 deletion test/asynchronous/test_connection_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
sys.path[0:0] = [""]

from test import unittest
from test.unified_format import generate_test_classes
from test.asynchronous.unified_format import generate_test_classes

_IS_SYNC = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
listener: CMAPListener
coll: AsyncCollection

async def asyncTearDown(self):
await reset_client_context()

@async_client_context.require_replica_set
async def asyncSetUp(self):
self.listener = CMAPListener()
Expand Down
6 changes: 6 additions & 0 deletions test/asynchronous/test_create_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ async def test_store_events_as_entities(self):
self.assertGreater(len(final_entity_map["events1"]), 0)
for event in final_entity_map["events1"]:
self.assertIn("PoolCreatedEvent", event["name"])
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()

async def test_store_all_others_as_entities(self):
self.scenario_runner = UnifiedSpecTestMixinV1()
Expand Down Expand Up @@ -122,6 +125,9 @@ async def test_store_all_others_as_entities(self):
self.assertEqual(entity_map["failures"], [])
self.assertEqual(entity_map["successes"], 2)
self.assertEqual(entity_map["iterations"], 5)
if self.scenario_runner.mongos_clients:
for client in self.scenario_runner.mongos_clients:
await client.close()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion test/asynchronous/test_crud_unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
sys.path[0:0] = [""]

from test import unittest
from test.unified_format import generate_test_classes
from test.asynchronous.unified_format import generate_test_classes

_IS_SYNC = False

Expand Down
2 changes: 1 addition & 1 deletion test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
unittest,
)
from test.asynchronous.test_bulk import AsyncBulkTestBase
from test.asynchronous.unified_format import generate_test_classes
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
from test.helpers import (
AWS_CREDS,
Expand All @@ -56,7 +57,6 @@
KMIP_CREDS,
LOCAL_MASTER_KEY,
)
from test.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
OvertCommandListener,
Expand Down
Loading
Loading