diff --git a/pymongo/_asyncio_task.py b/pymongo/_asyncio_task.py new file mode 100644 index 0000000000..8e457763d9 --- /dev/null +++ b/pymongo/_asyncio_task.py @@ -0,0 +1,49 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request. +Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling.""" + + +from __future__ import annotations + +import asyncio +import sys +from typing import Any, Coroutine, Optional + + +# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered +class _Task(asyncio.Task): + def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None: + super().__init__(coro, name=name) + self._cancel_requests = 0 + asyncio._register_task(self) + + def cancel(self, msg: Optional[str] = None) -> bool: + self._cancel_requests += 1 + return super().cancel(msg=msg) + + def uncancel(self) -> int: + if self._cancel_requests > 0: + self._cancel_requests -= 1 + return self._cancel_requests + + def cancelling(self) -> int: + return self._cancel_requests + + +def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task: + if sys.version_info >= (3, 11): + return asyncio.create_task(coro, name=name) + return _Task(coro, name=name) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 0dcdaa6c07..45824256da 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -476,7 +476,6 @@ async def _process_results_cursor( if op_type == "delete": res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] full_result[f"{op_type}Results"][original_index] = res - except Exception as exc: # Attempt to close the cursor, then raise top-level error. if cmd_cursor.alive: diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 735e543047..4802c3f54e 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -15,6 +15,7 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations +import asyncio import contextlib import enum import socket @@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptionError(exc) from exc @@ -200,6 +203,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. + except asyncio.CancelledError: + raise except Exception as error: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) @@ -722,6 +727,8 @@ async def create_encrypted_collection( await database.create_collection(name=name, **kwargs), encrypted_fields, ) + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 2ad57b03e7..ad1bc70aba 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -238,6 +238,9 @@ async def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. await self.close() + finally: + if self._executor._stopped: + await self._rtt_monitor.close() async def _check_server(self) -> ServerDescription: """Call hello or read the next streaming response. @@ -254,6 +257,8 @@ async def _check_server(self) -> ServerDescription: details = cast(Mapping[str, Any], exc.details) await self._topology.receive_cluster_time(details.get("$clusterTime")) raise + except asyncio.CancelledError: + raise except ReferenceError: raise except Exception as error: diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 377689047b..6ab6db2f7d 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -29,6 +29,7 @@ ) from pymongo import _csot, ssl_support +from pymongo._asyncio_task import create_task from pymongo.errors import _OperationCancelled from pymongo.socket_checker import _errno_from_exception @@ -259,12 +260,12 @@ async def async_receive_data( sock.settimeout(0.0) loop = asyncio.get_event_loop() - cancellation_task = asyncio.create_task(_poll_cancellation(conn)) + cancellation_task = create_task(_poll_cancellation(conn)) try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] + read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] else: - read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] tasks = [read_task, cancellation_task] done, pending = await asyncio.wait( tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 216a4457c7..2f89b91deb 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -23,6 +23,7 @@ import weakref from typing import Any, Optional +from pymongo._asyncio_task import create_task from pymongo.lock import _create_lock _IS_SYNC = False @@ -61,10 +62,11 @@ def __repr__(self) -> str: def open(self) -> None: """Start. Multiple calls have no effect.""" self._stopped = False - started = self._task and not self._task.done() - if not started: - self._task = asyncio.get_event_loop().create_task(self._run(), name=self._name) + if self._task is None or ( + self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined] + ): + self._task = create_task(self._run(), name=self._name) def close(self, dummy: Any = None) -> None: """Stop. To restart, call open(). @@ -83,7 +85,7 @@ async def join(self, timeout: Optional[int] = None) -> None: pass except asyncio.exceptions.CancelledError: # Task was already finished, or not yet started. - pass + raise def wake(self) -> None: """Execute the target function soon.""" @@ -97,6 +99,8 @@ def skip_sleep(self) -> None: async def _run(self) -> None: while not self._stopped: + if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined] + raise asyncio.CancelledError try: if not await self._target(): self._stopped = True diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 625e8429eb..9f6e3f7cf0 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -474,7 +474,6 @@ def _process_results_cursor( if op_type == "delete": res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] full_result[f"{op_type}Results"][original_index] = res - except Exception as exc: # Attempt to close the cursor, then raise top-level error. if cmd_cursor.alive: diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 506ff8bcba..09d0c0f2fd 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -15,6 +15,7 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations +import asyncio import contextlib import enum import socket @@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptionError(exc) from exc @@ -200,6 +203,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. + except asyncio.CancelledError: + raise except Exception as error: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) @@ -716,6 +721,8 @@ def create_encrypted_collection( database.create_collection(name=name, **kwargs), encrypted_fields, ) + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index a0b7635ab1..df4130d4ab 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -238,6 +238,9 @@ def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. self.close() + finally: + if self._executor._stopped: + self._rtt_monitor.close() def _check_server(self) -> ServerDescription: """Call hello or read the next streaming response. @@ -254,6 +257,8 @@ def _check_server(self) -> ServerDescription: details = cast(Mapping[str, Any], exc.details) self._topology.receive_cluster_time(details.get("$clusterTime")) raise + except asyncio.CancelledError: + raise except ReferenceError: raise except Exception as error: diff --git a/test/__init__.py b/test/__init__.py index dba3312424..d3a63db2d5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -868,8 +868,9 @@ def reset_client_context(): if _IS_SYNC: # sync tests don't need to reset a client context return - client_context.client.close() - client_context.client = None + elif client_context.client is not None: + client_context.client.close() + client_context.client = None client_context._init_client() @@ -1135,7 +1136,7 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC and client_context.client is not None: + if not _IS_SYNC: reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") @@ -1210,7 +1211,6 @@ def teardown(): c.drop_database("pymongo_test_mike") c.drop_database("pymongo_test_bernie") c.close() - print_running_clients() diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index bed49de161..73e2824742 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -870,8 +870,9 @@ async def reset_client_context(): if _IS_SYNC: # sync tests don't need to reset a client context return - await async_client_context.client.close() - async_client_context.client = None + elif async_client_context.client is not None: + await async_client_context.client.close() + async_client_context.client = None await async_client_context._init_client() @@ -1153,7 +1154,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC and async_client_context.client is not None: + if not _IS_SYNC: await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") @@ -1228,7 +1229,6 @@ async def async_teardown(): await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_bernie") await c.close() - print_running_clients() diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 292a78d645..db232386ee 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1280,6 +1280,7 @@ async def get_x(db): async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) @@ -1292,18 +1293,22 @@ async def test_server_selection_timeout(self): self.assertRaises( ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False ) + await client.close() client = AsyncMongoClient( "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False ) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) + await client.close() # Test invalid timeout in URI ignored and set to default. client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) diff --git a/test/test_client.py b/test/test_client.py index d41b0bbfda..5ec425f312 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1243,6 +1243,7 @@ def get_x(db): def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + client.close() client = MongoClient(serverSelectionTimeoutMS=0, connect=False) @@ -1253,16 +1254,20 @@ def test_server_selection_timeout(self): self.assertRaises( ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False ) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) + client.close() # Test invalid timeout in URI ignored and set to default. client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout)