1515"""Test the client_session module."""
1616from __future__ import annotations
1717
18+ import asyncio
1819import copy
1920import sys
2021import time
22+ from asyncio import iscoroutinefunction
2123from io import BytesIO
24+ from test .asynchronous .helpers import ExceptionCatchingTask
2225from typing import Any , Callable , List , Set , Tuple
2326
2427from pymongo .synchronous .mongo_client import MongoClient
3538)
3639from test .utils import (
3740 EventListener ,
38- ExceptionCatchingThread ,
3941 OvertCommandListener ,
4042 async_wait_until ,
4143)
@@ -184,16 +186,15 @@ async def _test_ops(self, client, *ops):
184186 f"{ f .__name__ } did not return implicit session to pool" ,
185187 )
186188
187- @async_client_context .require_sync
188- def test_implicit_sessions_checkout (self ):
189+ async def test_implicit_sessions_checkout (self ):
189190 # "To confirm that implicit sessions only allocate their server session after a
190191 # successful connection checkout" test from Driver Sessions Spec.
191192 succeeded = False
192193 lsid_set = set ()
193194 failures = 0
194195 for _ in range (5 ):
195196 listener = OvertCommandListener ()
196- client = self .async_rs_or_single_client (event_listeners = [listener ], maxPoolSize = 1 )
197+ client = await self .async_rs_or_single_client (event_listeners = [listener ], maxPoolSize = 1 )
197198 cursor = client .db .test .find ({})
198199 ops : List [Tuple [Callable , List [Any ]]] = [
199200 (client .db .test .find_one , [{"_id" : 1 }]),
@@ -210,26 +211,27 @@ def test_implicit_sessions_checkout(self):
210211 (cursor .distinct , ["_id" ]),
211212 (client .db .list_collections , []),
212213 ]
213- threads = []
214+ tasks = []
214215 listener .reset ()
215216
216- def thread_target (op , * args ):
217- res = op (* args )
217+ async def target (op , * args ):
218+ if iscoroutinefunction (op ):
219+ res = await op (* args )
220+ else :
221+ res = op (* args )
218222 if isinstance (res , (AsyncCursor , AsyncCommandCursor )):
219- list ( res ) # type: ignore[call-overload]
223+ await res . to_list ()
220224
221225 for op , args in ops :
222- threads .append (
223- ExceptionCatchingThread (
224- target = thread_target , args = [op , * args ], name = op .__name__
225- )
226+ tasks .append (
227+ ExceptionCatchingTask (target = target , args = [op , * args ], name = op .__name__ )
226228 )
227- threads [- 1 ].start ()
228- self .assertEqual (len (threads ), len (ops ))
229- for thread in threads :
230- thread .join ()
231- self .assertIsNone (thread .exc )
232- client .close ()
229+ await tasks [- 1 ].start ()
230+ self .assertEqual (len (tasks ), len (ops ))
231+ for t in tasks :
232+ await t .join ()
233+ self .assertIsNone (t .exc )
234+ await client .close ()
233235 lsid_set .clear ()
234236 for i in listener .started_events :
235237 if i .command .get ("lsid" ):
0 commit comments