Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymongo/asynchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ async def try_next(self) -> Optional[_DocumentType]:
if not _resumable(exc) and not exc.timeout:
await self.close()
raise
except Exception:
except BaseException:
Copy link
Member

@ShaneHarvey ShaneHarvey Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everywhere we catch BaseException, can you add a one line comment to explain it's intentional? Something like:

# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.

await self.close()
raise

Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ async def callback(session, custom_arg, custom_kwarg=None):
)
try:
ret = await callback(self)
except Exception as exc:
except BaseException as exc:
if self.in_transaction:
await self.abort_transaction()
if (
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
self._killed = True
await self.close()
raise
except Exception:
except BaseException:
await self.close()
raise

Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def try_next(self) -> Optional[_DocumentType]:
if not _resumable(exc) and not exc.timeout:
self.close()
raise
except Exception:
except BaseException:
self.close()
raise

Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def callback(session, custom_arg, custom_kwarg=None):
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try:
ret = callback(self)
except Exception as exc:
except BaseException as exc:
if self.in_transaction:
self.abort_transaction()
if (
Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
self._killed = True
self.close()
raise
except Exception:
except BaseException:
self.close()
raise

Expand Down
136 changes: 136 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test that async cancellation performed by users clean up resources correctly."""
from __future__ import annotations

import asyncio
import sys
from test.utils import async_get_pool, delay, one

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, connected


class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
client = await self.async_rs_or_single_client(maxPoolSize=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of these tests can reuse the global test client instead of creating new ones.

pool = await async_get_pool(client)
await connected(client)
conn = one(pool.conns)
await client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(client.db.test.drop)

async def task():
await client.db.test.find_one({"$where": delay(0.2)})

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(conn.closed)

@async_client_context.require_transactions
async def test_async_cancellation_aborts_transaction(self):
client = await self.async_rs_or_single_client()
await connected(client)
await client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(client.db.test.drop)

session = client.start_session()

async def callback(session):
await client.db.test.find_one({"$where": delay(0.2)}, session=session)

async def task():
await session.with_transaction(callback)

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertFalse(session.in_transaction)

@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_cursor(self):
client = await self.async_rs_or_single_client()
await connected(client)
for _ in range(2):
await client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(client.db.test.drop)

cursor = client.db.test.find({}, batch_size=1)
await cursor.next()

# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}

async def task():
async with self.fail_point(fail_command):
await cursor.next()

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(cursor._killed)

@async_client_context.require_change_streams
@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_change_stream(self):
client = await self.async_rs_or_single_client()
await connected(client)
self.addAsyncCleanup(client.db.test.drop)

change_stream = await client.db.test.watch(batch_size=2)

# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}

async def task():
async with self.fail_point(fail_command):
for _ in range(2):
await client.db.test.insert_one({"x": 1})
await change_stream.next()

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(change_stream._closed)
2 changes: 1 addition & 1 deletion tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@

def async_only_test(f: str) -> bool:
"""Return True for async tests that should not be converted to sync."""
return f in ["test_locks.py", "test_concurrency.py"]
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]


test_files = [
Expand Down
Loading