Skip to content

Commit 12e2e39

Browse files
committed
WIP
1 parent 51a69f5 commit 12e2e39

File tree

5 files changed

+64
-56
lines changed

5 files changed

+64
-56
lines changed

test/asynchronous/helpers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,16 @@ def disable(self):
381381

382382

383383
class ConcurrentRunner(PARENT):
384-
def __init__(self, name, *args, **kwargs):
384+
def __init__(self, **kwargs):
385385
if _IS_SYNC:
386-
super().__init__(*args, **kwargs)
387-
self.name = name
386+
super().__init__(**kwargs)
387+
self.name = kwargs.get("name", "ConcurrentRunner")
388388
self.stopped = False
389389
self.task = None
390390
if "target" in kwargs:
391391
self.target = kwargs["target"]
392+
if "args" in kwargs:
393+
self.args = kwargs["args"]
392394

393395
if not _IS_SYNC:
394396

@@ -407,23 +409,23 @@ async def run(self):
407409
if _IS_SYNC:
408410
super().run()
409411
else:
410-
await self.target()
412+
if self.args:
413+
await self.target(*self.args)
414+
else:
415+
await self.target()
411416
self.stopped = True
412417

413418

414419
class ExceptionCatchingTask(ConcurrentRunner):
415420
"""A Task that stores any exception encountered while running."""
416421

417-
def __init__(self, *args, **kwargs):
418-
super().__init__("ExceptionCatchingTask", *args, **kwargs)
422+
def __init__(self, **kwargs):
423+
super().__init__(**kwargs)
419424
self.exc = None
420425

421426
async def run(self):
422427
try:
423-
if _IS_SYNC:
424-
await super().run()
425-
else:
426-
await self.target()
428+
await super().run()
427429
except BaseException as exc:
428430
self.exc = exc
429431
raise

test/asynchronous/test_session.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
"""Test the client_session module."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import copy
1920
import sys
2021
import time
22+
from asyncio import iscoroutinefunction
2123
from io import BytesIO
24+
from test.asynchronous.helpers import ExceptionCatchingTask
2225
from typing import Any, Callable, List, Set, Tuple
2326

2427
from pymongo.synchronous.mongo_client import MongoClient
@@ -35,7 +38,6 @@
3538
)
3639
from test.utils import (
3740
EventListener,
38-
ExceptionCatchingThread,
3941
OvertCommandListener,
4042
async_wait_until,
4143
)
@@ -184,16 +186,15 @@ async def _test_ops(self, client, *ops):
184186
f"{f.__name__} did not return implicit session to pool",
185187
)
186188

187-
@async_client_context.require_sync
188-
def test_implicit_sessions_checkout(self):
189+
async def test_implicit_sessions_checkout(self):
189190
# "To confirm that implicit sessions only allocate their server session after a
190191
# successful connection checkout" test from Driver Sessions Spec.
191192
succeeded = False
192193
lsid_set = set()
193194
failures = 0
194195
for _ in range(5):
195196
listener = OvertCommandListener()
196-
client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
197+
client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
197198
cursor = client.db.test.find({})
198199
ops: List[Tuple[Callable, List[Any]]] = [
199200
(client.db.test.find_one, [{"_id": 1}]),
@@ -210,26 +211,27 @@ def test_implicit_sessions_checkout(self):
210211
(cursor.distinct, ["_id"]),
211212
(client.db.list_collections, []),
212213
]
213-
threads = []
214+
tasks = []
214215
listener.reset()
215216

216-
def thread_target(op, *args):
217-
res = op(*args)
217+
async def target(op, *args):
218+
if iscoroutinefunction(op):
219+
res = await op(*args)
220+
else:
221+
res = op(*args)
218222
if isinstance(res, (AsyncCursor, AsyncCommandCursor)):
219-
list(res) # type: ignore[call-overload]
223+
await res.to_list()
220224

221225
for op, args in ops:
222-
threads.append(
223-
ExceptionCatchingThread(
224-
target=thread_target, args=[op, *args], name=op.__name__
225-
)
226+
tasks.append(
227+
ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__)
226228
)
227-
threads[-1].start()
228-
self.assertEqual(len(threads), len(ops))
229-
for thread in threads:
230-
thread.join()
231-
self.assertIsNone(thread.exc)
232-
client.close()
229+
await tasks[-1].start()
230+
self.assertEqual(len(tasks), len(ops))
231+
for t in tasks:
232+
await t.join()
233+
self.assertIsNone(t.exc)
234+
await client.close()
233235
lsid_set.clear()
234236
for i in listener.started_events:
235237
if i.command.get("lsid"):

test/helpers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,16 @@ def disable(self):
381381

382382

383383
class ConcurrentRunner(PARENT):
384-
def __init__(self, name, *args, **kwargs):
384+
def __init__(self, **kwargs):
385385
if _IS_SYNC:
386-
super().__init__(*args, **kwargs)
387-
self.name = name
386+
super().__init__(**kwargs)
387+
self.name = kwargs.get("name", "ConcurrentRunner")
388388
self.stopped = False
389389
self.task = None
390390
if "target" in kwargs:
391391
self.target = kwargs["target"]
392+
if "args" in kwargs:
393+
self.args = kwargs["args"]
392394

393395
if not _IS_SYNC:
394396

@@ -407,23 +409,23 @@ def run(self):
407409
if _IS_SYNC:
408410
super().run()
409411
else:
410-
self.target()
412+
if self.args:
413+
self.target(*self.args)
414+
else:
415+
self.target()
411416
self.stopped = True
412417

413418

414419
class ExceptionCatchingTask(ConcurrentRunner):
415420
"""A Task that stores any exception encountered while running."""
416421

417-
def __init__(self, *args, **kwargs):
418-
super().__init__("ExceptionCatchingTask", *args, **kwargs)
422+
def __init__(self, **kwargs):
423+
super().__init__(**kwargs)
419424
self.exc = None
420425

421426
def run(self):
422427
try:
423-
if _IS_SYNC:
424-
super().run()
425-
else:
426-
self.target()
428+
super().run()
427429
except BaseException as exc:
428430
self.exc = exc
429431
raise

test/test_bson.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
sys.path[0:0] = [""]
3434

3535
from test import qcheck, unittest
36-
from test.utils import ExceptionCatchingThread
36+
from test.helpers import ExceptionCatchingTask
3737

3838
import bson
3939
from bson import (
@@ -1075,7 +1075,7 @@ def target(i):
10751075
my_int = type(f"MyInt_{i}_{j}", (int,), {})
10761076
bson.encode({"my_int": my_int()})
10771077

1078-
threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)]
1078+
threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)]
10791079
for t in threads:
10801080
t.start()
10811081

@@ -1114,7 +1114,7 @@ def __repr__(self):
11141114

11151115
def test_doc_in_invalid_document_error_message_mapping(self):
11161116
class MyMapping(abc.Mapping):
1117-
def keys():
1117+
def keys(self):
11181118
return ["t"]
11191119

11201120
def __getitem__(self, name):

test/test_session.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
"""Test the client_session module."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import copy
1920
import sys
2021
import time
22+
from asyncio import iscoroutinefunction
2123
from io import BytesIO
24+
from test.helpers import ExceptionCatchingTask
2225
from typing import Any, Callable, List, Set, Tuple
2326

2427
from pymongo.synchronous.mongo_client import MongoClient
@@ -35,7 +38,6 @@
3538
)
3639
from test.utils import (
3740
EventListener,
38-
ExceptionCatchingThread,
3941
OvertCommandListener,
4042
wait_until,
4143
)
@@ -184,7 +186,6 @@ def _test_ops(self, client, *ops):
184186
f"{f.__name__} did not return implicit session to pool",
185187
)
186188

187-
@client_context.require_sync
188189
def test_implicit_sessions_checkout(self):
189190
# "To confirm that implicit sessions only allocate their server session after a
190191
# successful connection checkout" test from Driver Sessions Spec.
@@ -210,25 +211,26 @@ def test_implicit_sessions_checkout(self):
210211
(cursor.distinct, ["_id"]),
211212
(client.db.list_collections, []),
212213
]
213-
threads = []
214+
tasks = []
214215
listener.reset()
215216

216-
def thread_target(op, *args):
217-
res = op(*args)
217+
def target(op, *args):
218+
if iscoroutinefunction(op):
219+
res = op(*args)
220+
else:
221+
res = op(*args)
218222
if isinstance(res, (Cursor, CommandCursor)):
219-
list(res) # type: ignore[call-overload]
223+
res.to_list()
220224

221225
for op, args in ops:
222-
threads.append(
223-
ExceptionCatchingThread(
224-
target=thread_target, args=[op, *args], name=op.__name__
225-
)
226+
tasks.append(
227+
ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__)
226228
)
227-
threads[-1].start()
228-
self.assertEqual(len(threads), len(ops))
229-
for thread in threads:
230-
thread.join()
231-
self.assertIsNone(thread.exc)
229+
tasks[-1].start()
230+
self.assertEqual(len(tasks), len(ops))
231+
for t in tasks:
232+
t.join()
233+
self.assertIsNone(t.exc)
232234
client.close()
233235
lsid_set.clear()
234236
for i in listener.started_events:

0 commit comments

Comments
 (0)