Skip to content

Commit 1276619

Browse files
committed
Subclass asyncio.Task
1 parent 547e950 commit 1276619

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

pymongo/_asyncio_task.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2024-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request.
16+
Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling."""
17+
18+
19+
from __future__ import annotations
20+
21+
import asyncio
22+
from typing import Any, Coroutine, Optional
23+
24+
25+
class _Task(asyncio.Task):
26+
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
27+
super().__init__(coro, name=name)
28+
self._cancelled: bool = False
29+
asyncio._register_task(self)
30+
31+
def cancel(self, msg: Optional[str] = None) -> bool:
32+
self._cancelled = True
33+
return super().cancel(msg=msg)
34+
35+
def is_cancelled(self) -> bool:
36+
return self._cancelled
37+
38+
39+
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> _Task:
40+
return _Task(coro, name=name)

pymongo/network_layer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030

3131
from pymongo import _csot, ssl_support
32+
from pymongo._asyncio_task import create_task
3233
from pymongo.errors import _OperationCancelled
3334
from pymongo.socket_checker import _errno_from_exception
3435

@@ -259,12 +260,12 @@ async def async_receive_data(
259260

260261
sock.settimeout(0.0)
261262
loop = asyncio.get_event_loop()
262-
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
263+
cancellation_task = create_task(_poll_cancellation(conn))
263264
try:
264265
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
265-
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
266+
read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
266267
else:
267-
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
268+
read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
268269
tasks = [read_task, cancellation_task]
269270
done, pending = await asyncio.wait(
270271
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED

pymongo/periodic_executor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import weakref
2424
from typing import Any, Optional
2525

26+
from pymongo._asyncio_task import _Task, create_task
2627
from pymongo.lock import _create_lock
2728

2829
_IS_SYNC = False
@@ -51,7 +52,7 @@ def __init__(
5152
self._min_interval = min_interval
5253
self._target = target
5354
self._stopped = False
54-
self._task: Optional[asyncio.Task] = None
55+
self._task: Optional[_Task] = None
5556
self._name = name
5657
self._skip_sleep = False
5758

@@ -63,9 +64,9 @@ def open(self) -> None:
6364
self._stopped = False
6465

6566
if self._task is None or (
66-
self._task.done() and not self._task.cancelled() and not self._task.cancelling()
67+
self._task.done() and not self._task.cancelled() and not self._task.is_cancelled()
6768
):
68-
self._task = asyncio.get_event_loop().create_task(self._run(), name=self._name)
69+
self._task = create_task(self._run(), name=self._name)
6970

7071
def close(self, dummy: Any = None) -> None:
7172
"""Stop. To restart, call open().
@@ -98,7 +99,7 @@ def skip_sleep(self) -> None:
9899

99100
async def _run(self) -> None:
100101
while not self._stopped:
101-
if self._task and self._task.cancelling():
102+
if self._task and self._task.is_cancelled():
102103
raise asyncio.CancelledError
103104
try:
104105
if not await self._target():

0 commit comments

Comments
 (0)