1717
1818import asyncio
1919import sys
20- import traceback
21-
22- from test .utils import async_get_pool , get_pool , one , delay
20+ from test .utils import async_get_pool , delay , one
2321
2422sys .path [0 :0 ] = ["" ]
2523
26- from test .asynchronous import AsyncIntegrationTest , connected , async_client_context
24+ from test .asynchronous import AsyncIntegrationTest , async_client_context , connected
2725
2826
2927class TestAsyncCancellation (AsyncIntegrationTest ):
3028 async def test_async_cancellation_closes_connection (self ):
31- client = await self .async_rs_or_single_client ()
29+ client = await self .async_rs_or_single_client (maxPoolSize = 1 )
3230 pool = await async_get_pool (client )
3331 await connected (client )
3432 conn = one (pool .conns )
33+ await client .db .test .insert_one ({"x" : 1 })
34+ self .addAsyncCleanup (client .db .test .drop )
3535
3636 async def task ():
37- await client .db .test .find_one ({"$where" : delay (1.0 )})
37+ await client .db .test .find_one ({"$where" : delay (0.2 )})
3838
3939 task = asyncio .create_task (task ())
4040
@@ -50,11 +50,13 @@ async def task():
5050 async def test_async_cancellation_aborts_transaction (self ):
5151 client = await self .async_rs_or_single_client ()
5252 await connected (client )
53+ await client .db .test .insert_one ({"x" : 1 })
54+ self .addAsyncCleanup (client .db .test .drop )
5355
5456 session = client .start_session ()
5557
5658 async def callback (session ):
57- await client .db .test .find_one ({"$where" : delay (1.0 )} )
59+ await client .db .test .find_one ({"$where" : delay (0.2 )}, session = session )
5860
5961 async def task ():
6062 await session .with_transaction (callback )
@@ -69,3 +71,33 @@ async def task():
6971
7072 self .assertFalse (session .in_transaction )
7173
74+ async def test_async_cancellation_kills_cursor (self ):
75+ client = await self .async_rs_or_single_client ()
76+ await connected (client )
77+ for _ in range (2 ):
78+ await client .db .test .insert_one ({"x" : 1 })
79+ self .addAsyncCleanup (client .db .test .drop )
80+
81+ cursor = client .db .test .find ({}, batch_size = 1 )
82+ await cursor .next ()
83+
84+ # Make sure getMore commands block
85+ fail_command = {
86+ "configureFailPoint" : "failCommand" ,
87+ "mode" : "alwaysOn" ,
88+ "data" : {"failCommands" : ["getMore" ], "blockConnection" : True , "blockTimeMS" : 200 },
89+ }
90+
91+ async def task ():
92+ async with self .fail_point (fail_command ):
93+ await cursor .next ()
94+
95+ task = asyncio .create_task (task ())
96+
97+ await asyncio .sleep (0.1 )
98+
99+ task .cancel ()
100+ with self .assertRaises (asyncio .CancelledError ):
101+ await task
102+
103+ self .assertTrue (cursor ._killed )
0 commit comments