Skip to content

Commit e455fa6

Browse files
committed
Use one loop for asyncio test suite
1 parent 78724cd commit e455fa6

File tree

2 files changed

+121
-47
lines changed

2 files changed

+121
-47
lines changed

test/__init__.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,6 @@
3030
import unittest
3131
import warnings
3232
from asyncio import iscoroutinefunction
33-
from test.helpers import (
34-
COMPRESSORS,
35-
IS_SRV,
36-
MONGODB_API_VERSION,
37-
MULTI_MONGOS_LB_URI,
38-
TEST_LOADBALANCER,
39-
TEST_SERVERLESS,
40-
TLS_OPTIONS,
41-
SystemCertsPatcher,
42-
client_knobs,
43-
db_pwd,
44-
db_user,
45-
global_knobs,
46-
host,
47-
is_server_resolvable,
48-
port,
49-
print_running_topology,
50-
print_thread_stacks,
51-
print_thread_tracebacks,
52-
sanitize_cmd,
53-
sanitize_reply,
54-
)
5533

5634
from pymongo.uri_parser import parse_uri
5735

@@ -63,7 +41,6 @@
6341
HAVE_IPADDRESS = False
6442
from contextlib import contextmanager
6543
from functools import partial, wraps
66-
from test.version import Version
6744
from typing import Any, Callable, Dict, Generator, overload
6845
from unittest import SkipTest
6946
from urllib.parse import quote_plus
@@ -78,6 +55,32 @@
7855
from pymongo.synchronous.database import Database
7956
from pymongo.synchronous.mongo_client import MongoClient
8057

58+
sys.path[0:0] = [""]
59+
60+
from test.helpers import (
61+
COMPRESSORS,
62+
IS_SRV,
63+
MONGODB_API_VERSION,
64+
MULTI_MONGOS_LB_URI,
65+
TEST_LOADBALANCER,
66+
TEST_SERVERLESS,
67+
TLS_OPTIONS,
68+
SystemCertsPatcher,
69+
client_knobs,
70+
db_pwd,
71+
db_user,
72+
global_knobs,
73+
host,
74+
is_server_resolvable,
75+
port,
76+
print_running_topology,
77+
print_thread_stacks,
78+
print_thread_tracebacks,
79+
sanitize_cmd,
80+
sanitize_reply,
81+
)
82+
from test.version import Version
83+
8184
_IS_SYNC = True
8285

8386

@@ -875,6 +878,40 @@ def reset_client_context():
875878

876879

877880
class PyMongoTestCase(unittest.TestCase):
881+
if not _IS_SYNC:
882+
# Customize async TestCase to use a single event loop for all tests.
883+
def __init__(self, methodName="runTest"):
884+
super().__init__(methodName)
885+
try:
886+
self.loop = asyncio.get_event_loop()
887+
except RuntimeError:
888+
self.loop = asyncio.new_event_loop()
889+
asyncio.set_event_loop(self.loop)
890+
891+
def setUp(self):
892+
pass
893+
894+
def tearDown(self):
895+
pass
896+
897+
# See TestCase.addCleanup.
898+
def addCleanup(self, func, /, *args, **kwargs):
899+
self.addCleanup(*(func, *args), **kwargs)
900+
901+
def run(self, result=None):
902+
if result is None:
903+
result = self.defaultTestResult()
904+
result.startTest(self)
905+
906+
try:
907+
self.setUp()
908+
self.loop.run_until_complete(self.setUp())
909+
self.loop.run_until_complete(getattr(self, self._testMethodName)())
910+
self.loop.run_until_complete(self.tearDown())
911+
finally:
912+
self.tearDown()
913+
result.stopTest(self)
914+
878915
def assertEqualCommand(self, expected, actual, msg=None):
879916
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
880917

test/asynchronous/__init__.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,6 @@
3030
import unittest
3131
import warnings
3232
from asyncio import iscoroutinefunction
33-
from test.helpers import (
34-
COMPRESSORS,
35-
IS_SRV,
36-
MONGODB_API_VERSION,
37-
MULTI_MONGOS_LB_URI,
38-
TEST_LOADBALANCER,
39-
TEST_SERVERLESS,
40-
TLS_OPTIONS,
41-
SystemCertsPatcher,
42-
client_knobs,
43-
db_pwd,
44-
db_user,
45-
global_knobs,
46-
host,
47-
is_server_resolvable,
48-
port,
49-
print_running_topology,
50-
print_thread_stacks,
51-
print_thread_tracebacks,
52-
sanitize_cmd,
53-
sanitize_reply,
54-
)
5533

5634
from pymongo.uri_parser import parse_uri
5735

@@ -63,7 +41,6 @@
6341
HAVE_IPADDRESS = False
6442
from contextlib import asynccontextmanager, contextmanager
6543
from functools import partial, wraps
66-
from test.version import Version
6744
from typing import Any, Callable, Dict, Generator, overload
6845
from unittest import SkipTest
6946
from urllib.parse import quote_plus
@@ -78,6 +55,32 @@
7855
from pymongo.server_api import ServerApi
7956
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
8057

58+
sys.path[0:0] = [""]
59+
60+
from test.helpers import (
61+
COMPRESSORS,
62+
IS_SRV,
63+
MONGODB_API_VERSION,
64+
MULTI_MONGOS_LB_URI,
65+
TEST_LOADBALANCER,
66+
TEST_SERVERLESS,
67+
TLS_OPTIONS,
68+
SystemCertsPatcher,
69+
client_knobs,
70+
db_pwd,
71+
db_user,
72+
global_knobs,
73+
host,
74+
is_server_resolvable,
75+
port,
76+
print_running_topology,
77+
print_thread_stacks,
78+
print_thread_tracebacks,
79+
sanitize_cmd,
80+
sanitize_reply,
81+
)
82+
from test.version import Version
83+
8184
_IS_SYNC = False
8285

8386

@@ -876,7 +879,41 @@ async def reset_client_context():
876879
await async_client_context._init_client()
877880

878881

879-
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
882+
class AsyncPyMongoTestCase(unittest.TestCase):
883+
if not _IS_SYNC:
884+
# Customize async TestCase to use a single event loop for all tests.
885+
def __init__(self, methodName="runTest"):
886+
super().__init__(methodName)
887+
try:
888+
self.loop = asyncio.get_event_loop()
889+
except RuntimeError:
890+
self.loop = asyncio.new_event_loop()
891+
asyncio.set_event_loop(self.loop)
892+
893+
async def asyncSetUp(self):
894+
pass
895+
896+
async def asyncTearDown(self):
897+
pass
898+
899+
# See IsolatedAsyncioTestCase.addAsyncCleanup.
900+
def addAsyncCleanup(self, func, /, *args, **kwargs):
901+
self.addCleanup(*(func, *args), **kwargs)
902+
903+
def run(self, result=None):
904+
if result is None:
905+
result = self.defaultTestResult()
906+
result.startTest(self)
907+
908+
try:
909+
self.setUp()
910+
self.loop.run_until_complete(self.asyncSetUp())
911+
self.loop.run_until_complete(getattr(self, self._testMethodName)())
912+
self.loop.run_until_complete(self.asyncTearDown())
913+
finally:
914+
self.tearDown()
915+
result.stopTest(self)
916+
880917
def assertEqualCommand(self, expected, actual, msg=None):
881918
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
882919

0 commit comments

Comments
 (0)