Skip to content

Commit 51a69f5

Browse files
committed
Refactor
1 parent d9a494c commit 51a69f5

File tree

6 files changed

+161
-227
lines changed

6 files changed

+161
-227
lines changed

test/asynchronous/helpers.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,5 +404,28 @@ def is_alive(self):
404404

405405
async def run(self):
406406
if self.target:
407-
await self.target()
407+
if _IS_SYNC:
408+
super().run()
409+
else:
410+
await self.target()
408411
self.stopped = True
412+
413+
414+
class ExceptionCatchingTask(ConcurrentRunner):
415+
"""A Task that stores any exception encountered while running."""
416+
417+
def __init__(self, *args, **kwargs):
418+
super().__init__("ExceptionCatchingTask", *args, **kwargs)
419+
self.exc = None
420+
421+
async def run(self):
422+
try:
423+
if _IS_SYNC:
424+
await super().run()
425+
else:
426+
await self.target()
427+
except BaseException as exc:
428+
self.exc = exc
429+
raise
430+
finally:
431+
self.stopped = True

test/asynchronous/test_load_balancer.py

Lines changed: 52 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import pathlib
2222
import sys
2323
import threading
24+
from asyncio import Event
25+
from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask
2426

2527
import pytest
2628

@@ -29,10 +31,9 @@
2931
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
3032
from test.asynchronous.unified_format import generate_test_classes
3133
from test.utils import (
32-
ExceptionCatchingTask,
33-
ExceptionCatchingThread,
3434
async_get_pool,
3535
async_wait_until,
36+
create_async_event,
3637
)
3738

3839
from pymongo.asynchronous.helpers import anext
@@ -116,35 +117,19 @@ async def _test_no_gc_deadlock(self, create_resource):
116117
if async_client_context.load_balancer:
117118
self.assertEqual(pool.active_sockets, 1) # Pinned.
118119

119-
if _IS_SYNC:
120-
thread = PoolLocker(pool)
121-
thread.start()
122-
self.assertTrue(thread.locked.wait(5), "timed out")
123-
# Garbage collect the resource while the pool is locked to ensure we
124-
# don't deadlock.
125-
del resource
126-
# On PyPy it can take a few rounds to collect the cursor.
127-
for _ in range(3):
128-
gc.collect()
129-
thread.unlock.set()
130-
thread.join(5)
131-
self.assertFalse(thread.is_alive())
132-
self.assertIsNone(thread.exc)
133-
134-
else:
135-
task = PoolLocker(pool)
136-
self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type]
137-
138-
# Garbage collect the resource while the pool is locked to ensure we
139-
# don't deadlock.
140-
del resource
141-
# On PyPy it can take a few rounds to collect the cursor.
142-
for _ in range(3):
143-
gc.collect()
144-
task.unlock.set()
145-
await task.run()
146-
self.assertFalse(task.is_alive())
147-
self.assertIsNone(task.exc)
120+
task = PoolLocker(pool)
121+
await task.start()
122+
self.assertTrue(await task.wait(task.locked, 5), "timed out")
123+
# Garbage collect the resource while the pool is locked to ensure we
124+
# don't deadlock.
125+
del resource
126+
# On PyPy it can take a few rounds to collect the cursor.
127+
for _ in range(3):
128+
gc.collect()
129+
task.unlock.set()
130+
await task.join(5)
131+
self.assertFalse(task.is_alive())
132+
self.assertIsNone(task.exc)
148133

149134
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
150135
# Run another operation to ensure the socket still works.
@@ -164,80 +149,50 @@ async def test_session_gc(self):
164149
if async_client_context.load_balancer:
165150
self.assertEqual(pool.active_sockets, 1) # Pinned.
166151

167-
if _IS_SYNC:
168-
thread = PoolLocker(pool)
169-
thread.start()
170-
self.assertTrue(thread.locked.wait(5), "timed out")
171-
# Garbage collect the session while the pool is locked to ensure we
172-
# don't deadlock.
173-
del session
174-
# On PyPy it can take a few rounds to collect the session.
175-
for _ in range(3):
176-
gc.collect()
177-
thread.unlock.set()
178-
thread.join(5)
179-
self.assertFalse(thread.is_alive())
180-
self.assertIsNone(thread.exc)
181-
182-
else:
183-
task = PoolLocker(pool)
184-
self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type]
185-
186-
# Garbage collect the session while the pool is locked to ensure we
187-
# don't deadlock.
188-
del session
189-
# On PyPy it can take a few rounds to collect the cursor.
190-
for _ in range(3):
191-
gc.collect()
192-
task.unlock.set()
193-
await task.run()
194-
self.assertFalse(task.is_alive())
195-
self.assertIsNone(task.exc)
152+
task = PoolLocker(pool)
153+
await task.start()
154+
self.assertTrue(await task.wait(task.locked, 5), "timed out")
155+
# Garbage collect the session while the pool is locked to ensure we
156+
# don't deadlock.
157+
del session
158+
# On PyPy it can take a few rounds to collect the session.
159+
for _ in range(3):
160+
gc.collect()
161+
task.unlock.set()
162+
await task.join(5)
163+
self.assertFalse(task.is_alive())
164+
self.assertIsNone(task.exc)
196165

197166
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
198167
# Run another operation to ensure the socket still works.
199168
await client[self.db.name].test.delete_many({})
200169

201170

202-
if _IS_SYNC:
203-
204-
class PoolLocker(ExceptionCatchingThread):
205-
def __init__(self, pool):
206-
super().__init__(target=self.lock_pool)
207-
self.pool = pool
208-
self.daemon = True
209-
self.locked = threading.Event()
210-
self.unlock = threading.Event()
211-
212-
def lock_pool(self):
213-
with self.pool.lock:
214-
self.locked.set()
215-
# Wait for the unlock flag.
216-
unlock_pool = self.unlock.wait(10)
217-
if not unlock_pool:
218-
raise Exception("timed out waiting for unlock signal: deadlock?")
171+
class PoolLocker(ExceptionCatchingTask):
172+
def __init__(self, pool):
173+
super().__init__(target=self.lock_pool)
174+
self.pool = pool
175+
self.daemon = True
176+
self.locked = create_async_event()
177+
self.unlock = create_async_event()
219178

220-
else:
179+
async def lock_pool(self):
180+
async with self.pool.lock:
181+
self.locked.set()
182+
# Wait for the unlock flag.
183+
unlock_pool = await self.wait(self.unlock, 10)
184+
if not unlock_pool:
185+
raise Exception("timed out waiting for unlock signal: deadlock?")
221186

222-
class PoolLocker(ExceptionCatchingTask):
223-
def __init__(self, pool):
224-
super().__init__(self.lock_pool)
225-
self.pool = pool
226-
self.daemon = True
227-
self.locked = asyncio.Event()
228-
self.unlock = asyncio.Event()
229-
230-
async def lock_pool(self):
231-
async with self.pool.lock:
232-
self.locked.set()
233-
# Wait for the unlock flag.
234-
try:
235-
await asyncio.wait_for(self.unlock.wait(), timeout=10)
236-
except asyncio.TimeoutError:
237-
raise Exception("timed out waiting for unlock signal: deadlock?")
238-
239-
def is_alive(self):
240-
return not self.task.done()
187+
async def wait(self, event: Event, timeout: int):
188+
if _IS_SYNC:
189+
return event.wait(timeout)
190+
else:
191+
try:
192+
await asyncio.wait_for(event.wait(), timeout=timeout)
193+
except asyncio.TimeoutError:
194+
return False
195+
return True
241196

242197

243198
if __name__ == "__main__":

test/helpers.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,5 +404,28 @@ def is_alive(self):
404404

405405
def run(self):
406406
if self.target:
407-
self.target()
407+
if _IS_SYNC:
408+
super().run()
409+
else:
410+
self.target()
408411
self.stopped = True
412+
413+
414+
class ExceptionCatchingTask(ConcurrentRunner):
415+
"""A Task that stores any exception encountered while running."""
416+
417+
def __init__(self, *args, **kwargs):
418+
super().__init__("ExceptionCatchingTask", *args, **kwargs)
419+
self.exc = None
420+
421+
def run(self):
422+
try:
423+
if _IS_SYNC:
424+
super().run()
425+
else:
426+
self.target()
427+
except BaseException as exc:
428+
self.exc = exc
429+
raise
430+
finally:
431+
self.stopped = True

0 commit comments

Comments
 (0)