Skip to content

Commit 29c16db

Browse files
authored
PYTHON-4981 - Create workaround for asyncio.Task.cancelling support in older Python versions (#2009)
1 parent ce51864 commit 29c16db

File tree

13 files changed

+103
-17
lines changed

13 files changed

+103
-17
lines changed

pymongo/_asyncio_task.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2024-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
16+
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
17+
18+
19+
from __future__ import annotations
20+
21+
import asyncio
22+
import sys
23+
from typing import Any, Coroutine, Optional
24+
25+
26+
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
27+
class _Task(asyncio.Task):
28+
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
29+
super().__init__(coro, name=name)
30+
self._cancel_requests = 0
31+
asyncio._register_task(self)
32+
33+
def cancel(self, msg: Optional[str] = None) -> bool:
34+
self._cancel_requests += 1
35+
return super().cancel(msg=msg)
36+
37+
def uncancel(self) -> int:
38+
if self._cancel_requests > 0:
39+
self._cancel_requests -= 1
40+
return self._cancel_requests
41+
42+
def cancelling(self) -> int:
43+
return self._cancel_requests
44+
45+
46+
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
47+
if sys.version_info >= (3, 11):
48+
return asyncio.create_task(coro, name=name)
49+
return _Task(coro, name=name)

pymongo/asynchronous/client_bulk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,6 @@ async def _process_results_cursor(
476476
if op_type == "delete":
477477
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
478478
full_result[f"{op_type}Results"][original_index] = res
479-
480479
except Exception as exc:
481480
# Attempt to close the cursor, then raise top-level error.
482481
if cmd_cursor.alive:

pymongo/asynchronous/encryption.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Support for explicit client-side field level encryption."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import contextlib
1920
import enum
2021
import socket
@@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
111112
# BSON encoding/decoding errors are unrelated to encryption so
112113
# we should propagate them unchanged.
113114
raise
115+
except asyncio.CancelledError:
116+
raise
114117
except Exception as exc:
115118
raise EncryptionError(exc) from exc
116119

@@ -200,6 +203,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
200203
conn.close()
201204
except (PyMongoError, MongoCryptError):
202205
raise # Propagate pymongo errors directly.
206+
except asyncio.CancelledError:
207+
raise
203208
except Exception as error:
204209
# Wrap I/O errors in PyMongo exceptions.
205210
_raise_connection_failure((host, port), error)
@@ -722,6 +727,8 @@ async def create_encrypted_collection(
722727
await database.create_collection(name=name, **kwargs),
723728
encrypted_fields,
724729
)
730+
except asyncio.CancelledError:
731+
raise
725732
except Exception as exc:
726733
raise EncryptedCollectionError(exc, encrypted_fields) from exc
727734

pymongo/asynchronous/monitor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ async def _run(self) -> None:
238238
except ReferenceError:
239239
# Topology was garbage-collected.
240240
await self.close()
241+
finally:
242+
if self._executor._stopped:
243+
await self._rtt_monitor.close()
241244

242245
async def _check_server(self) -> ServerDescription:
243246
"""Call hello or read the next streaming response.
@@ -254,6 +257,8 @@ async def _check_server(self) -> ServerDescription:
254257
details = cast(Mapping[str, Any], exc.details)
255258
await self._topology.receive_cluster_time(details.get("$clusterTime"))
256259
raise
260+
except asyncio.CancelledError:
261+
raise
257262
except ReferenceError:
258263
raise
259264
except Exception as error:

pymongo/network_layer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030

3131
from pymongo import _csot, ssl_support
32+
from pymongo._asyncio_task import create_task
3233
from pymongo.errors import _OperationCancelled
3334
from pymongo.socket_checker import _errno_from_exception
3435

@@ -259,12 +260,12 @@ async def async_receive_data(
259260

260261
sock.settimeout(0.0)
261262
loop = asyncio.get_event_loop()
262-
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
263+
cancellation_task = create_task(_poll_cancellation(conn))
263264
try:
264265
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
265-
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
266+
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
266267
else:
267-
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
268+
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
268269
tasks = [read_task, cancellation_task]
269270
done, pending = await asyncio.wait(
270271
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED

pymongo/periodic_executor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import weakref
2424
from typing import Any, Optional
2525

26+
from pymongo._asyncio_task import create_task
2627
from pymongo.lock import _create_lock
2728

2829
_IS_SYNC = False
@@ -61,10 +62,11 @@ def __repr__(self) -> str:
6162
def open(self) -> None:
6263
"""Start. Multiple calls have no effect."""
6364
self._stopped = False
64-
started = self._task and not self._task.done()
6565

66-
if not started:
67-
self._task = asyncio.get_event_loop().create_task(self._run(), name=self._name)
66+
if self._task is None or (
67+
self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined]
68+
):
69+
self._task = create_task(self._run(), name=self._name)
6870

6971
def close(self, dummy: Any = None) -> None:
7072
"""Stop. To restart, call open().
@@ -83,7 +85,7 @@ async def join(self, timeout: Optional[int] = None) -> None:
8385
pass
8486
except asyncio.exceptions.CancelledError:
8587
# Task was already finished, or not yet started.
86-
pass
88+
raise
8789

8890
def wake(self) -> None:
8991
"""Execute the target function soon."""
@@ -97,6 +99,8 @@ def skip_sleep(self) -> None:
9799

98100
async def _run(self) -> None:
99101
while not self._stopped:
102+
if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined]
103+
raise asyncio.CancelledError
100104
try:
101105
if not await self._target():
102106
self._stopped = True

pymongo/synchronous/client_bulk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,6 @@ def _process_results_cursor(
474474
if op_type == "delete":
475475
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
476476
full_result[f"{op_type}Results"][original_index] = res
477-
478477
except Exception as exc:
479478
# Attempt to close the cursor, then raise top-level error.
480479
if cmd_cursor.alive:

pymongo/synchronous/encryption.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Support for explicit client-side field level encryption."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import contextlib
1920
import enum
2021
import socket
@@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
111112
# BSON encoding/decoding errors are unrelated to encryption so
112113
# we should propagate them unchanged.
113114
raise
115+
except asyncio.CancelledError:
116+
raise
114117
except Exception as exc:
115118
raise EncryptionError(exc) from exc
116119

@@ -200,6 +203,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
200203
conn.close()
201204
except (PyMongoError, MongoCryptError):
202205
raise # Propagate pymongo errors directly.
206+
except asyncio.CancelledError:
207+
raise
203208
except Exception as error:
204209
# Wrap I/O errors in PyMongo exceptions.
205210
_raise_connection_failure((host, port), error)
@@ -716,6 +721,8 @@ def create_encrypted_collection(
716721
database.create_collection(name=name, **kwargs),
717722
encrypted_fields,
718723
)
724+
except asyncio.CancelledError:
725+
raise
719726
except Exception as exc:
720727
raise EncryptedCollectionError(exc, encrypted_fields) from exc
721728

pymongo/synchronous/monitor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def _run(self) -> None:
238238
except ReferenceError:
239239
# Topology was garbage-collected.
240240
self.close()
241+
finally:
242+
if self._executor._stopped:
243+
self._rtt_monitor.close()
241244

242245
def _check_server(self) -> ServerDescription:
243246
"""Call hello or read the next streaming response.
@@ -254,6 +257,8 @@ def _check_server(self) -> ServerDescription:
254257
details = cast(Mapping[str, Any], exc.details)
255258
self._topology.receive_cluster_time(details.get("$clusterTime"))
256259
raise
260+
except asyncio.CancelledError:
261+
raise
257262
except ReferenceError:
258263
raise
259264
except Exception as error:

test/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -868,8 +868,9 @@ def reset_client_context():
868868
if _IS_SYNC:
869869
# sync tests don't need to reset a client context
870870
return
871-
client_context.client.close()
872-
client_context.client = None
871+
elif client_context.client is not None:
872+
client_context.client.close()
873+
client_context.client = None
873874
client_context._init_client()
874875

875876

@@ -1135,7 +1136,7 @@ class IntegrationTest(PyMongoTestCase):
11351136

11361137
@client_context.require_connection
11371138
def setUp(self) -> None:
1138-
if not _IS_SYNC and client_context.client is not None:
1139+
if not _IS_SYNC:
11391140
reset_client_context()
11401141
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11411142
raise SkipTest("this test does not support load balancers")
@@ -1210,7 +1211,6 @@ def teardown():
12101211
c.drop_database("pymongo_test_mike")
12111212
c.drop_database("pymongo_test_bernie")
12121213
c.close()
1213-
12141214
print_running_clients()
12151215

12161216

0 commit comments

Comments
 (0)