Skip to content

Commit 547e950

Browse files
committed
Explicitly raise asyncio.CancelledError when tasks are being cancelled
1 parent ce51864 commit 547e950

File tree

11 files changed

+48
-13
lines changed

11 files changed

+48
-13
lines changed

pymongo/asynchronous/client_bulk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,6 @@ async def _process_results_cursor(
476476
if op_type == "delete":
477477
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
478478
full_result[f"{op_type}Results"][original_index] = res
479-
480479
except Exception as exc:
481480
# Attempt to close the cursor, then raise top-level error.
482481
if cmd_cursor.alive:

pymongo/asynchronous/encryption.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Support for explicit client-side field level encryption."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import contextlib
1920
import enum
2021
import socket
@@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
111112
# BSON encoding/decoding errors are unrelated to encryption so
112113
# we should propagate them unchanged.
113114
raise
115+
except asyncio.CancelledError:
116+
raise
114117
except Exception as exc:
115118
raise EncryptionError(exc) from exc
116119

@@ -200,6 +203,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
200203
conn.close()
201204
except (PyMongoError, MongoCryptError):
202205
raise # Propagate pymongo errors directly.
206+
except asyncio.CancelledError:
207+
raise
203208
except Exception as error:
204209
# Wrap I/O errors in PyMongo exceptions.
205210
_raise_connection_failure((host, port), error)
@@ -722,6 +727,8 @@ async def create_encrypted_collection(
722727
await database.create_collection(name=name, **kwargs),
723728
encrypted_fields,
724729
)
730+
except asyncio.CancelledError:
731+
raise
725732
except Exception as exc:
726733
raise EncryptedCollectionError(exc, encrypted_fields) from exc
727734

pymongo/asynchronous/monitor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ async def _run(self) -> None:
238238
except ReferenceError:
239239
# Topology was garbage-collected.
240240
await self.close()
241+
finally:
242+
if self._executor._stopped:
243+
await self._rtt_monitor.close()
241244

242245
async def _check_server(self) -> ServerDescription:
243246
"""Call hello or read the next streaming response.
@@ -254,6 +257,8 @@ async def _check_server(self) -> ServerDescription:
254257
details = cast(Mapping[str, Any], exc.details)
255258
await self._topology.receive_cluster_time(details.get("$clusterTime"))
256259
raise
260+
except asyncio.CancelledError:
261+
raise
257262
except ReferenceError:
258263
raise
259264
except Exception as error:

pymongo/periodic_executor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ def __repr__(self) -> str:
6161
def open(self) -> None:
6262
"""Start. Multiple calls have no effect."""
6363
self._stopped = False
64-
started = self._task and not self._task.done()
6564

66-
if not started:
65+
if self._task is None or (
66+
self._task.done() and not self._task.cancelled() and not self._task.cancelling()
67+
):
6768
self._task = asyncio.get_event_loop().create_task(self._run(), name=self._name)
6869

6970
def close(self, dummy: Any = None) -> None:
@@ -83,7 +84,7 @@ async def join(self, timeout: Optional[int] = None) -> None:
8384
pass
8485
except asyncio.exceptions.CancelledError:
8586
# Task was already finished, or not yet started.
86-
pass
87+
raise
8788

8889
def wake(self) -> None:
8990
"""Execute the target function soon."""
@@ -97,6 +98,8 @@ def skip_sleep(self) -> None:
9798

9899
async def _run(self) -> None:
99100
while not self._stopped:
101+
if self._task and self._task.cancelling():
102+
raise asyncio.CancelledError
100103
try:
101104
if not await self._target():
102105
self._stopped = True

pymongo/synchronous/client_bulk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,6 @@ def _process_results_cursor(
474474
if op_type == "delete":
475475
res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment]
476476
full_result[f"{op_type}Results"][original_index] = res
477-
478477
except Exception as exc:
479478
# Attempt to close the cursor, then raise top-level error.
480479
if cmd_cursor.alive:

pymongo/synchronous/encryption.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Support for explicit client-side field level encryption."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import contextlib
1920
import enum
2021
import socket
@@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]:
111112
# BSON encoding/decoding errors are unrelated to encryption so
112113
# we should propagate them unchanged.
113114
raise
115+
except asyncio.CancelledError:
116+
raise
114117
except Exception as exc:
115118
raise EncryptionError(exc) from exc
116119

@@ -200,6 +203,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
200203
conn.close()
201204
except (PyMongoError, MongoCryptError):
202205
raise # Propagate pymongo errors directly.
206+
except asyncio.CancelledError:
207+
raise
203208
except Exception as error:
204209
# Wrap I/O errors in PyMongo exceptions.
205210
_raise_connection_failure((host, port), error)
@@ -716,6 +721,8 @@ def create_encrypted_collection(
716721
database.create_collection(name=name, **kwargs),
717722
encrypted_fields,
718723
)
724+
except asyncio.CancelledError:
725+
raise
719726
except Exception as exc:
720727
raise EncryptedCollectionError(exc, encrypted_fields) from exc
721728

pymongo/synchronous/monitor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def _run(self) -> None:
238238
except ReferenceError:
239239
# Topology was garbage-collected.
240240
self.close()
241+
finally:
242+
if self._executor._stopped:
243+
self._rtt_monitor.close()
241244

242245
def _check_server(self) -> ServerDescription:
243246
"""Call hello or read the next streaming response.
@@ -254,6 +257,8 @@ def _check_server(self) -> ServerDescription:
254257
details = cast(Mapping[str, Any], exc.details)
255258
self._topology.receive_cluster_time(details.get("$clusterTime"))
256259
raise
260+
except asyncio.CancelledError:
261+
raise
257262
except ReferenceError:
258263
raise
259264
except Exception as error:

test/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -868,8 +868,9 @@ def reset_client_context():
868868
if _IS_SYNC:
869869
# sync tests don't need to reset a client context
870870
return
871-
client_context.client.close()
872-
client_context.client = None
871+
elif client_context.client is not None:
872+
client_context.client.close()
873+
client_context.client = None
873874
client_context._init_client()
874875

875876

@@ -1135,7 +1136,7 @@ class IntegrationTest(PyMongoTestCase):
11351136

11361137
@client_context.require_connection
11371138
def setUp(self) -> None:
1138-
if not _IS_SYNC and client_context.client is not None:
1139+
if not _IS_SYNC:
11391140
reset_client_context()
11401141
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11411142
raise SkipTest("this test does not support load balancers")
@@ -1210,7 +1211,6 @@ def teardown():
12101211
c.drop_database("pymongo_test_mike")
12111212
c.drop_database("pymongo_test_bernie")
12121213
c.close()
1213-
12141214
print_running_clients()
12151215

12161216

test/asynchronous/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -870,8 +870,9 @@ async def reset_client_context():
870870
if _IS_SYNC:
871871
# sync tests don't need to reset a client context
872872
return
873-
await async_client_context.client.close()
874-
async_client_context.client = None
873+
elif async_client_context.client is not None:
874+
await async_client_context.client.close()
875+
async_client_context.client = None
875876
await async_client_context._init_client()
876877

877878

@@ -1153,7 +1154,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
11531154

11541155
@async_client_context.require_connection
11551156
async def asyncSetUp(self) -> None:
1156-
if not _IS_SYNC and async_client_context.client is not None:
1157+
if not _IS_SYNC:
11571158
await reset_client_context()
11581159
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11591160
raise SkipTest("this test does not support load balancers")
@@ -1228,7 +1229,6 @@ async def async_teardown():
12281229
await c.drop_database("pymongo_test_mike")
12291230
await c.drop_database("pymongo_test_bernie")
12301231
await c.close()
1231-
12321232
print_running_clients()
12331233

12341234

test/asynchronous/test_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ async def get_x(db):
12801280
async def test_server_selection_timeout(self):
12811281
client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False)
12821282
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
1283+
await client.close()
12831284

12841285
client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False)
12851286

@@ -1292,18 +1293,22 @@ async def test_server_selection_timeout(self):
12921293
self.assertRaises(
12931294
ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False
12941295
)
1296+
await client.close()
12951297

12961298
client = AsyncMongoClient(
12971299
"mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False
12981300
)
12991301
self.assertAlmostEqual(0.1, client.options.server_selection_timeout)
1302+
await client.close()
13001303

13011304
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False)
13021305
self.assertAlmostEqual(0, client.options.server_selection_timeout)
1306+
await client.close()
13031307

13041308
# Test invalid timeout in URI ignored and set to default.
13051309
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False)
13061310
self.assertAlmostEqual(30, client.options.server_selection_timeout)
1311+
await client.close()
13071312

13081313
client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False)
13091314
self.assertAlmostEqual(30, client.options.server_selection_timeout)

0 commit comments

Comments
 (0)