Skip to content

PYTHON-4745 - Test behavior of async task cancellation #2136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymongo/asynchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ async def try_next(self) -> Optional[_DocumentType]:
if not _resumable(exc) and not exc.timeout:
await self.close()
raise
except Exception:
except BaseException:
Copy link
Member

@ShaneHarvey ShaneHarvey Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everywhere we catch BaseException, can you add a one line comment to explain it's intentional? Something like:

# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.

await self.close()
raise

Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ async def callback(session, custom_arg, custom_kwarg=None):
)
try:
ret = await callback(self)
except Exception as exc:
except BaseException as exc:
if self.in_transaction:
await self.abort_transaction()
if (
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
self._killed = True
await self.close()
raise
except Exception:
except BaseException:
await self.close()
raise

Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def try_next(self) -> Optional[_DocumentType]:
if not _resumable(exc) and not exc.timeout:
self.close()
raise
except Exception:
except BaseException:
self.close()
raise

Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def callback(session, custom_arg, custom_kwarg=None):
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try:
ret = callback(self)
except Exception as exc:
except BaseException as exc:
if self.in_transaction:
self.abort_transaction()
if (
Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
self._killed = True
self.close()
raise
except Exception:
except BaseException:
self.close()
raise

Expand Down
136 changes: 136 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test that async cancellation performed by users clean up resources correctly."""
from __future__ import annotations

import asyncio
import sys
from test.utils import async_get_pool, delay, one

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, connected


class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
client = await self.async_rs_or_single_client(maxPoolSize=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of these tests can reuse the global test client instead of creating new ones.

pool = await async_get_pool(client)
await connected(client)
conn = one(pool.conns)
await client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(client.db.test.drop)

async def task():
await client.db.test.find_one({"$where": delay(0.2)})

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(conn.closed)

@async_client_context.require_transactions
async def test_async_cancellation_aborts_transaction(self):
client = await self.async_rs_or_single_client()
await connected(client)
await client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(client.db.test.drop)

session = client.start_session()

async def callback(session):
await client.db.test.find_one({"$where": delay(0.2)}, session=session)

async def task():
await session.with_transaction(callback)

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertFalse(session.in_transaction)

@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_cursor(self):
client = await self.async_rs_or_single_client()
await connected(client)
for _ in range(2):
await client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(client.db.test.drop)

cursor = client.db.test.find({}, batch_size=1)
await cursor.next()

# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}

async def task():
async with self.fail_point(fail_command):
await cursor.next()

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(cursor._killed)

@async_client_context.require_change_streams
@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_change_stream(self):
client = await self.async_rs_or_single_client()
await connected(client)
self.addAsyncCleanup(client.db.test.drop)

change_stream = await client.db.test.watch(batch_size=2)

# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}

async def task():
async with self.fail_point(fail_command):
for _ in range(2):
await client.db.test.insert_one({"x": 1})
await change_stream.next()

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(change_stream._closed)
2 changes: 1 addition & 1 deletion tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@

def async_only_test(f: str) -> bool:
"""Return True for async tests that should not be converted to sync."""
return f in ["test_locks.py", "test_concurrency.py"]
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]


test_files = [
Expand Down
Loading