Skip to content

Commit fd18b7a

Browse files
committed
Verify that connections, transactions, and cursors are terminated when cancelled
1 parent 91eb68f commit fd18b7a

File tree

5 files changed

+43
-11
lines changed

5 files changed

+43
-11
lines changed

pymongo/asynchronous/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ async def callback(session, custom_arg, custom_kwarg=None):
697697
)
698698
try:
699699
ret = await callback(self)
700-
except Exception as exc:
700+
except BaseException as exc:
701701
if self.in_transaction:
702702
await self.abort_transaction()
703703
if (

pymongo/asynchronous/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11261126
self._killed = True
11271127
await self.close()
11281128
raise
1129-
except Exception:
1129+
except BaseException:
11301130
await self.close()
11311131
raise
11321132

pymongo/synchronous/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def callback(session, custom_arg, custom_kwarg=None):
694694
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
695695
try:
696696
ret = callback(self)
697-
except Exception as exc:
697+
except BaseException as exc:
698698
if self.in_transaction:
699699
self.abort_transaction()
700700
if (

pymongo/synchronous/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11241124
self._killed = True
11251125
self.close()
11261126
raise
1127-
except Exception:
1127+
except BaseException:
11281128
self.close()
11291129
raise
11301130

test/asynchronous/test_async_cancellation.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,24 @@
1717

1818
import asyncio
1919
import sys
20-
import traceback
21-
22-
from test.utils import async_get_pool, get_pool, one, delay
20+
from test.utils import async_get_pool, delay, one
2321

2422
sys.path[0:0] = [""]
2523

26-
from test.asynchronous import AsyncIntegrationTest, connected, async_client_context
24+
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected
2725

2826

2927
class TestAsyncCancellation(AsyncIntegrationTest):
3028
async def test_async_cancellation_closes_connection(self):
31-
client = await self.async_rs_or_single_client()
29+
client = await self.async_rs_or_single_client(maxPoolSize=1)
3230
pool = await async_get_pool(client)
3331
await connected(client)
3432
conn = one(pool.conns)
33+
await client.db.test.insert_one({"x": 1})
34+
self.addAsyncCleanup(client.db.test.drop)
3535

3636
async def task():
37-
await client.db.test.find_one({"$where": delay(1.0)})
37+
await client.db.test.find_one({"$where": delay(0.2)})
3838

3939
task = asyncio.create_task(task())
4040

@@ -50,11 +50,13 @@ async def task():
5050
async def test_async_cancellation_aborts_transaction(self):
5151
client = await self.async_rs_or_single_client()
5252
await connected(client)
53+
await client.db.test.insert_one({"x": 1})
54+
self.addAsyncCleanup(client.db.test.drop)
5355

5456
session = client.start_session()
5557

5658
async def callback(session):
57-
await client.db.test.find_one({"$where": delay(1.0)})
59+
await client.db.test.find_one({"$where": delay(0.2)}, session=session)
5860

5961
async def task():
6062
await session.with_transaction(callback)
@@ -69,3 +71,33 @@ async def task():
6971

7072
self.assertFalse(session.in_transaction)
7173

74+
async def test_async_cancellation_kills_cursor(self):
75+
client = await self.async_rs_or_single_client()
76+
await connected(client)
77+
for _ in range(2):
78+
await client.db.test.insert_one({"x": 1})
79+
self.addAsyncCleanup(client.db.test.drop)
80+
81+
cursor = client.db.test.find({}, batch_size=1)
82+
await cursor.next()
83+
84+
# Make sure getMore commands block
85+
fail_command = {
86+
"configureFailPoint": "failCommand",
87+
"mode": "alwaysOn",
88+
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
89+
}
90+
91+
async def task():
92+
async with self.fail_point(fail_command):
93+
await cursor.next()
94+
95+
task = asyncio.create_task(task())
96+
97+
await asyncio.sleep(0.1)
98+
99+
task.cancel()
100+
with self.assertRaises(asyncio.CancelledError):
101+
await task
102+
103+
self.assertTrue(cursor._killed)

0 commit comments

Comments
 (0)