Skip to content

Commit 4906897

Browse files
committed
Test change_stream cancellation
1 parent fd18b7a commit 4906897

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

pymongo/asynchronous/change_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ async def try_next(self) -> Optional[_DocumentType]:
391391
if not _resumable(exc) and not exc.timeout:
392392
await self.close()
393393
raise
394-
except Exception:
394+
except BaseException:
395395
await self.close()
396396
raise
397397

pymongo/synchronous/change_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def try_next(self) -> Optional[_DocumentType]:
389389
if not _resumable(exc) and not exc.timeout:
390390
self.close()
391391
raise
392-
except Exception:
392+
except BaseException:
393393
self.close()
394394
raise
395395

test/asynchronous/test_async_cancellation.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Test that async cancellation performed by users raises the expected error."""
15+
"""Test that async cancellation performed by users clean up resources correctly."""
1616
from __future__ import annotations
1717

1818
import asyncio
@@ -71,7 +71,7 @@ async def task():
7171

7272
self.assertFalse(session.in_transaction)
7373

74-
async def test_async_cancellation_kills_cursor(self):
74+
async def test_async_cancellation_closes_cursor(self):
7575
client = await self.async_rs_or_single_client()
7676
await connected(client)
7777
for _ in range(2):
@@ -101,3 +101,33 @@ async def task():
101101
await task
102102

103103
self.assertTrue(cursor._killed)
104+
105+
async def test_async_cancellation_closes_change_stream(self):
106+
client = await self.async_rs_or_single_client()
107+
await connected(client)
108+
self.addAsyncCleanup(client.db.test.drop)
109+
110+
change_stream = await client.db.test.watch(batch_size=2)
111+
112+
# Make sure getMore commands block
113+
fail_command = {
114+
"configureFailPoint": "failCommand",
115+
"mode": "alwaysOn",
116+
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
117+
}
118+
119+
async def task():
120+
async with self.fail_point(fail_command):
121+
for _ in range(2):
122+
await client.db.test.insert_one({"x": 1})
123+
await change_stream.next()
124+
125+
task = asyncio.create_task(task())
126+
127+
await asyncio.sleep(0.1)
128+
129+
task.cancel()
130+
with self.assertRaises(asyncio.CancelledError):
131+
await task
132+
133+
self.assertTrue(change_stream._closed)

0 commit comments

Comments
 (0)