2121import pathlib
2222import sys
2323import threading
24+ from asyncio import Event
25+ from test .asynchronous .helpers import ConcurrentRunner , ExceptionCatchingTask
2426
2527import pytest
2628
2931from test .asynchronous import AsyncIntegrationTest , async_client_context , unittest
3032from test .asynchronous .unified_format import generate_test_classes
3133from test .utils import (
32- ExceptionCatchingTask ,
33- ExceptionCatchingThread ,
3434 async_get_pool ,
3535 async_wait_until ,
36+ create_async_event ,
3637)
3738
3839from pymongo .asynchronous .helpers import anext
@@ -116,35 +117,19 @@ async def _test_no_gc_deadlock(self, create_resource):
116117 if async_client_context .load_balancer :
117118 self .assertEqual (pool .active_sockets , 1 ) # Pinned.
118119
119- if _IS_SYNC :
120- thread = PoolLocker (pool )
121- thread .start ()
122- self .assertTrue (thread .locked .wait (5 ), "timed out" )
123- # Garbage collect the resource while the pool is locked to ensure we
124- # don't deadlock.
125- del resource
126- # On PyPy it can take a few rounds to collect the cursor.
127- for _ in range (3 ):
128- gc .collect ()
129- thread .unlock .set ()
130- thread .join (5 )
131- self .assertFalse (thread .is_alive ())
132- self .assertIsNone (thread .exc )
133-
134- else :
135- task = PoolLocker (pool )
136- self .assertTrue (await asyncio .wait_for (task .locked .wait (), timeout = 5 ), "timed out" ) # type: ignore[arg-type]
137-
138- # Garbage collect the resource while the pool is locked to ensure we
139- # don't deadlock.
140- del resource
141- # On PyPy it can take a few rounds to collect the cursor.
142- for _ in range (3 ):
143- gc .collect ()
144- task .unlock .set ()
145- await task .run ()
146- self .assertFalse (task .is_alive ())
147- self .assertIsNone (task .exc )
120+ task = PoolLocker (pool )
121+ await task .start ()
122+ self .assertTrue (await task .wait (task .locked , 5 ), "timed out" )
123+ # Garbage collect the resource while the pool is locked to ensure we
124+ # don't deadlock.
125+ del resource
126+ # On PyPy it can take a few rounds to collect the cursor.
127+ for _ in range (3 ):
128+ gc .collect ()
129+ task .unlock .set ()
130+ await task .join (5 )
131+ self .assertFalse (task .is_alive ())
132+ self .assertIsNone (task .exc )
148133
149134 await async_wait_until (lambda : pool .active_sockets == 0 , "return socket" )
150135 # Run another operation to ensure the socket still works.
@@ -164,80 +149,50 @@ async def test_session_gc(self):
164149 if async_client_context .load_balancer :
165150 self .assertEqual (pool .active_sockets , 1 ) # Pinned.
166151
167- if _IS_SYNC :
168- thread = PoolLocker (pool )
169- thread .start ()
170- self .assertTrue (thread .locked .wait (5 ), "timed out" )
171- # Garbage collect the session while the pool is locked to ensure we
172- # don't deadlock.
173- del session
174- # On PyPy it can take a few rounds to collect the session.
175- for _ in range (3 ):
176- gc .collect ()
177- thread .unlock .set ()
178- thread .join (5 )
179- self .assertFalse (thread .is_alive ())
180- self .assertIsNone (thread .exc )
181-
182- else :
183- task = PoolLocker (pool )
184- self .assertTrue (await asyncio .wait_for (task .locked .wait (), timeout = 5 ), "timed out" ) # type: ignore[arg-type]
185-
186- # Garbage collect the session while the pool is locked to ensure we
187- # don't deadlock.
188- del session
189- # On PyPy it can take a few rounds to collect the cursor.
190- for _ in range (3 ):
191- gc .collect ()
192- task .unlock .set ()
193- await task .run ()
194- self .assertFalse (task .is_alive ())
195- self .assertIsNone (task .exc )
152+ task = PoolLocker (pool )
153+ await task .start ()
154+ self .assertTrue (await task .wait (task .locked , 5 ), "timed out" )
155+ # Garbage collect the session while the pool is locked to ensure we
156+ # don't deadlock.
157+ del session
158+ # On PyPy it can take a few rounds to collect the session.
159+ for _ in range (3 ):
160+ gc .collect ()
161+ task .unlock .set ()
162+ await task .join (5 )
163+ self .assertFalse (task .is_alive ())
164+ self .assertIsNone (task .exc )
196165
197166 await async_wait_until (lambda : pool .active_sockets == 0 , "return socket" )
198167 # Run another operation to ensure the socket still works.
199168 await client [self .db .name ].test .delete_many ({})
200169
201170
202- if _IS_SYNC :
203-
204- class PoolLocker (ExceptionCatchingThread ):
205- def __init__ (self , pool ):
206- super ().__init__ (target = self .lock_pool )
207- self .pool = pool
208- self .daemon = True
209- self .locked = threading .Event ()
210- self .unlock = threading .Event ()
211-
212- def lock_pool (self ):
213- with self .pool .lock :
214- self .locked .set ()
215- # Wait for the unlock flag.
216- unlock_pool = self .unlock .wait (10 )
217- if not unlock_pool :
218- raise Exception ("timed out waiting for unlock signal: deadlock?" )
171+ class PoolLocker (ExceptionCatchingTask ):
172+ def __init__ (self , pool ):
173+ super ().__init__ (target = self .lock_pool )
174+ self .pool = pool
175+ self .daemon = True
176+ self .locked = create_async_event ()
177+ self .unlock = create_async_event ()
219178
220- else :
179+ async def lock_pool (self ):
180+ async with self .pool .lock :
181+ self .locked .set ()
182+ # Wait for the unlock flag.
183+ unlock_pool = await self .wait (self .unlock , 10 )
184+ if not unlock_pool :
185+ raise Exception ("timed out waiting for unlock signal: deadlock?" )
221186
222- class PoolLocker (ExceptionCatchingTask ):
223- def __init__ (self , pool ):
224- super ().__init__ (self .lock_pool )
225- self .pool = pool
226- self .daemon = True
227- self .locked = asyncio .Event ()
228- self .unlock = asyncio .Event ()
229-
230- async def lock_pool (self ):
231- async with self .pool .lock :
232- self .locked .set ()
233- # Wait for the unlock flag.
234- try :
235- await asyncio .wait_for (self .unlock .wait (), timeout = 10 )
236- except asyncio .TimeoutError :
237- raise Exception ("timed out waiting for unlock signal: deadlock?" )
238-
239- def is_alive (self ):
240- return not self .task .done ()
187+ async def wait (self , event : Event , timeout : int ):
188+ if _IS_SYNC :
189+ return event .wait (timeout )
190+ else :
191+ try :
192+ await asyncio .wait_for (event .wait (), timeout = timeout )
193+ except asyncio .TimeoutError :
194+ return False
195+ return True
241196
242197
243198if __name__ == "__main__" :
0 commit comments