From e455fa60ed5f3a26e5a4aa25e74f2c29f367b9ee Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 28 Jan 2025 12:13:14 -0800 Subject: [PATCH 1/8] Use one loop for asyncio test suite --- test/__init__.py | 83 ++++++++++++++++++++++++---------- test/asynchronous/__init__.py | 85 +++++++++++++++++++++++++---------- 2 files changed, 121 insertions(+), 47 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index d3a63db2d5..d2e51572a7 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -30,28 +30,6 @@ import unittest import warnings from asyncio import iscoroutinefunction -from test.helpers import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - client_knobs, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_topology, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) from pymongo.uri_parser import parse_uri @@ -63,7 +41,6 @@ HAVE_IPADDRESS = False from contextlib import contextmanager from functools import partial, wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -78,6 +55,32 @@ from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient +sys.path[0:0] = [""] + +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) +from test.version import Version + _IS_SYNC = True @@ -875,6 +878,40 @@ def reset_client_context(): class PyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # Customize async TestCase to use a single event loop for all tests. + def __init__(self, methodName="runTest"): + super().__init__(methodName) + try: + self.loop = asyncio.get_event_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def setUp(self): + pass + + def tearDown(self): + pass + + # See TestCase.addCleanup. + def addCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def run(self, result=None): + if result is None: + result = self.defaultTestResult() + result.startTest(self) + + try: + self.setUp() + self.loop.run_until_complete(self.setUp()) + self.loop.run_until_complete(getattr(self, self._testMethodName)()) + self.loop.run_until_complete(self.tearDown()) + finally: + self.tearDown() + result.stopTest(self) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 73e2824742..52ed7df243 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -30,28 +30,6 @@ import unittest import warnings from asyncio import iscoroutinefunction -from test.helpers import ( - COMPRESSORS, - IS_SRV, - MONGODB_API_VERSION, - MULTI_MONGOS_LB_URI, - TEST_LOADBALANCER, - TEST_SERVERLESS, - TLS_OPTIONS, - SystemCertsPatcher, - client_knobs, - db_pwd, - db_user, - global_knobs, - host, - is_server_resolvable, - port, - print_running_topology, - print_thread_stacks, - print_thread_tracebacks, - sanitize_cmd, - sanitize_reply, -) from pymongo.uri_parser import parse_uri @@ -63,7 +41,6 @@ HAVE_IPADDRESS = False from contextlib import asynccontextmanager, contextmanager from functools import partial, wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, overload from unittest import SkipTest from urllib.parse import quote_plus @@ -78,6 +55,32 @@ from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +sys.path[0:0] = [""] + +from test.helpers import ( + COMPRESSORS, + IS_SRV, + MONGODB_API_VERSION, + MULTI_MONGOS_LB_URI, + TEST_LOADBALANCER, + TEST_SERVERLESS, + TLS_OPTIONS, + SystemCertsPatcher, + client_knobs, + db_pwd, + db_user, + global_knobs, + host, + is_server_resolvable, + port, + print_running_topology, + print_thread_stacks, + print_thread_tracebacks, + sanitize_cmd, + sanitize_reply, +) +from test.version import Version + _IS_SYNC = False @@ -876,7 +879,41 @@ async def reset_client_context(): await async_client_context._init_client() -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): +class AsyncPyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # Customize async TestCase to use a single event loop for all tests. + def __init__(self, methodName="runTest"): + super().__init__(methodName) + try: + self.loop = asyncio.get_event_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass + + # See IsolatedAsyncioTestCase.addAsyncCleanup. + def addAsyncCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def run(self, result=None): + if result is None: + result = self.defaultTestResult() + result.startTest(self) + + try: + self.setUp() + self.loop.run_until_complete(self.asyncSetUp()) + self.loop.run_until_complete(getattr(self, self._testMethodName)()) + self.loop.run_until_complete(self.asyncTearDown()) + finally: + self.tearDown() + result.stopTest(self) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) From eb75d7763969646ad1a1c0ff52b472447a8b2cb9 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 28 Jan 2025 12:53:09 -0800 Subject: [PATCH 2/8] PYTHON-5071 Fix AsyncPyMongoTestCase implementation --- test/__init__.py | 40 ++++++++++++++++++++++++----------- test/asynchronous/__init__.py | 40 ++++++++++++++++++++++++----------- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index d2e51572a7..94c243a10a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import inspect import logging import multiprocessing import os @@ -898,19 +899,34 @@ def tearDown(self): def addCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) - def run(self, result=None): - if result is None: - result = self.defaultTestResult() - result.startTest(self) + def _callSetUp(self): + self._callAsync(self.setUp) - try: - self.setUp() - self.loop.run_until_complete(self.setUp()) - self.loop.run_until_complete(getattr(self, self._testMethodName)()) - self.loop.run_until_complete(self.tearDown()) - finally: - self.tearDown() - result.stopTest(self) + def _callTestMethod(self, method): + if self._callMaybeAsync(method) is not None: + warnings.warn( + f"It is deprecated to return a value that is not None from a " + f"test case ({method})", + DeprecationWarning, + stacklevel=4, + ) + + def _callTearDown(self): + self._callAsync(self.tearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return self.loop.run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return self.loop.run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 52ed7df243..f8c55f2216 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import inspect import logging import multiprocessing import os @@ -900,19 +901,34 @@ async def asyncTearDown(self): def addAsyncCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) - def run(self, result=None): - if result is None: - result = self.defaultTestResult() - result.startTest(self) + def _callSetUp(self): + self._callAsync(self.asyncSetUp) - try: - self.setUp() - self.loop.run_until_complete(self.asyncSetUp()) - self.loop.run_until_complete(getattr(self, self._testMethodName)()) - self.loop.run_until_complete(self.asyncTearDown()) - finally: - self.tearDown() - result.stopTest(self) + def _callTestMethod(self, method): + if self._callMaybeAsync(method) is not None: + warnings.warn( + f"It is deprecated to return a value that is not None from a " + f"test case ({method})", + DeprecationWarning, + stacklevel=4, + ) + + def _callTearDown(self): + self._callAsync(self.asyncTearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return self.loop.run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return self.loop.run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) From c08fcebe3e9b599187457064dc484076633c7c9f Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 28 Jan 2025 13:18:21 -0800 Subject: [PATCH 3/8] PYTHON-5071 Fix setup --- test/__init__.py | 12 ++++-------- test/asynchronous/__init__.py | 12 ++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 94c243a10a..e69b78834b 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -880,7 +880,8 @@ def reset_client_context(): class PyMongoTestCase(unittest.TestCase): if not _IS_SYNC: - # Customize async TestCase to use a single event loop for all tests. + # An async TestCase that uses a single event loop for all tests. + # Inspired by TestCase. def __init__(self, methodName="runTest"): super().__init__(methodName) try: @@ -900,16 +901,11 @@ def addCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) def _callSetUp(self): + self.setUp() self._callAsync(self.setUp) def _callTestMethod(self, method): - if self._callMaybeAsync(method) is not None: - warnings.warn( - f"It is deprecated to return a value that is not None from a " - f"test case ({method})", - DeprecationWarning, - stacklevel=4, - ) + self._callMaybeAsync(method) def _callTearDown(self): self._callAsync(self.tearDown) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index f8c55f2216..76d5c479fa 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -882,7 +882,8 @@ async def reset_client_context(): class AsyncPyMongoTestCase(unittest.TestCase): if not _IS_SYNC: - # Customize async TestCase to use a single event loop for all tests. + # An async TestCase that uses a single event loop for all tests. + # Inspired by IsolatedAsyncioTestCase. def __init__(self, methodName="runTest"): super().__init__(methodName) try: @@ -902,16 +903,11 @@ def addAsyncCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) def _callSetUp(self): + self.setUp() self._callAsync(self.asyncSetUp) def _callTestMethod(self, method): - if self._callMaybeAsync(method) is not None: - warnings.warn( - f"It is deprecated to return a value that is not None from a " - f"test case ({method})", - DeprecationWarning, - stacklevel=4, - ) + self._callMaybeAsync(method) def _callTearDown(self): self._callAsync(self.asyncTearDown) From 2ed7c609dcc99edcd90abb8723ebcee946d7d081 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 28 Jan 2025 13:36:29 -0800 Subject: [PATCH 4/8] PYTHON-5071 Ignore get_event_loop deprecation --- test/__init__.py | 5 ++++- test/asynchronous/__init__.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index e69b78834b..13fbd5e2dd 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -885,7 +885,10 @@ class PyMongoTestCase(unittest.TestCase): def __init__(self, methodName="runTest"): super().__init__(methodName) try: - self.loop = asyncio.get_event_loop() + with warnings.catch_warnings(): + # Ignore DeprecationWarning: There is no current event loop + warnings.simplefilter("ignore", DeprecationWarning) + self.loop = asyncio.get_event_loop() except RuntimeError: self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 76d5c479fa..4cb1c82152 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -887,7 +887,10 @@ class AsyncPyMongoTestCase(unittest.TestCase): def __init__(self, methodName="runTest"): super().__init__(methodName) try: - self.loop = asyncio.get_event_loop() + with warnings.catch_warnings(): + # Ignore DeprecationWarning: There is no current event loop + warnings.simplefilter("ignore", DeprecationWarning) + self.loop = asyncio.get_event_loop() except RuntimeError: self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) From 87cbd36c2b3854109d8804b21d4469a9b55765f6 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 29 Jan 2025 11:04:45 -0800 Subject: [PATCH 5/8] PYTHON-5071 Stop reinitializing client_context --- test/__init__.py | 26 +++++++++++--------------- test/asynchronous/__init__.py | 26 +++++++++++--------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 13fbd5e2dd..65b96fc895 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -868,30 +868,28 @@ def max_message_size_bytes(self): client_context = ClientContext() -def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif client_context.client is not None: - client_context.client.close() - client_context.client = None - client_context._init_client() - - class PyMongoTestCase(unittest.TestCase): if not _IS_SYNC: # An async TestCase that uses a single event loop for all tests. # Inspired by TestCase. def __init__(self, methodName="runTest"): super().__init__(methodName) + self._loop = None + + @property + def loop(self): + if self._loop: + return self._loop try: with warnings.catch_warnings(): # Ignore DeprecationWarning: There is no current event loop warnings.simplefilter("ignore", DeprecationWarning) - self.loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() except RuntimeError: - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._loop = loop + return loop def setUp(self): pass @@ -1188,8 +1186,6 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: - reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 4cb1c82152..236a823399 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -870,30 +870,28 @@ async def max_message_size_bytes(self): async_client_context = AsyncClientContext() -async def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif async_client_context.client is not None: - await async_client_context.client.close() - async_client_context.client = None - await async_client_context._init_client() - - class AsyncPyMongoTestCase(unittest.TestCase): if not _IS_SYNC: # An async TestCase that uses a single event loop for all tests. # Inspired by IsolatedAsyncioTestCase. def __init__(self, methodName="runTest"): super().__init__(methodName) + self._loop = None + + @property + def loop(self): + if self._loop: + return self._loop try: with warnings.catch_warnings(): # Ignore DeprecationWarning: There is no current event loop warnings.simplefilter("ignore", DeprecationWarning) - self.loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() except RuntimeError: - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._loop = loop + return loop async def asyncSetUp(self): pass @@ -1206,8 +1204,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: - await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): From 148482cdf0bb901494cae9202839dbfec6ae6a81 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 29 Jan 2025 11:29:16 -0800 Subject: [PATCH 6/8] PYTHON-5071 Bind to pytest's event loop so that it gets cleaned up at exit --- test/__init__.py | 47 +++++++++++++++++++---------------- test/asynchronous/__init__.py | 47 +++++++++++++++++++---------------- test/asynchronous/conftest.py | 2 +- test/conftest.py | 2 +- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 65b96fc895..c32bbf4663 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -872,25 +872,6 @@ class PyMongoTestCase(unittest.TestCase): if not _IS_SYNC: # An async TestCase that uses a single event loop for all tests. # Inspired by TestCase. - def __init__(self, methodName="runTest"): - super().__init__(methodName) - self._loop = None - - @property - def loop(self): - if self._loop: - return self._loop - try: - with warnings.catch_warnings(): - # Ignore DeprecationWarning: There is no current event loop - warnings.simplefilter("ignore", DeprecationWarning) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._loop = loop - return loop - def setUp(self): pass @@ -917,11 +898,11 @@ def _callCleanup(self, function, *args, **kwargs): def _callAsync(self, func, /, *args, **kwargs): assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" - return self.loop.run_until_complete(func(*args, **kwargs)) + return get_loop().run_until_complete(func(*args, **kwargs)) def _callMaybeAsync(self, func, /, *args, **kwargs): if inspect.iscoroutinefunction(func): - return self.loop.run_until_complete(func(*args, **kwargs)) + return get_loop().run_until_complete(func(*args, **kwargs)) else: return func(*args, **kwargs) @@ -1233,7 +1214,31 @@ def tearDown(self) -> None: super().tearDown() +LOOP = None + + +def get_loop() -> asyncio.AbstractEventLoop: + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + + def setup(): + if not _IS_SYNC: + global LOOP + LOOP = asyncio.get_running_loop() client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 236a823399..3e0365229f 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -874,25 +874,6 @@ class AsyncPyMongoTestCase(unittest.TestCase): if not _IS_SYNC: # An async TestCase that uses a single event loop for all tests. # Inspired by IsolatedAsyncioTestCase. - def __init__(self, methodName="runTest"): - super().__init__(methodName) - self._loop = None - - @property - def loop(self): - if self._loop: - return self._loop - try: - with warnings.catch_warnings(): - # Ignore DeprecationWarning: There is no current event loop - warnings.simplefilter("ignore", DeprecationWarning) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._loop = loop - return loop - async def asyncSetUp(self): pass @@ -919,11 +900,11 @@ def _callCleanup(self, function, *args, **kwargs): def _callAsync(self, func, /, *args, **kwargs): assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" - return self.loop.run_until_complete(func(*args, **kwargs)) + return get_loop().run_until_complete(func(*args, **kwargs)) def _callMaybeAsync(self, func, /, *args, **kwargs): if inspect.iscoroutinefunction(func): - return self.loop.run_until_complete(func(*args, **kwargs)) + return get_loop().run_until_complete(func(*args, **kwargs)) else: return func(*args, **kwargs) @@ -1251,7 +1232,31 @@ async def asyncTearDown(self) -> None: await super().asyncTearDown() +LOOP = None + + +def get_loop() -> asyncio.AbstractEventLoop: + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + + async def async_setup(): + if not _IS_SYNC: + global LOOP + LOOP = asyncio.get_running_loop() await async_client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/conftest.py b/test/asynchronous/conftest.py index a27a9f213d..e443dff6c0 100644 --- a/test/asynchronous/conftest.py +++ b/test/asynchronous/conftest.py @@ -22,7 +22,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest_asyncio.fixture(scope="package", autouse=True) +@pytest_asyncio.fixture(scope="session", autouse=True) async def test_setup_and_teardown(): await async_setup() yield diff --git a/test/conftest.py b/test/conftest.py index 91fad28d0a..a3d954c7c3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -20,7 +20,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest.fixture(scope="package", autouse=True) +@pytest.fixture(scope="session", autouse=True) def test_setup_and_teardown(): setup() yield From 3a2d84496aa035ff91b9114d559dfc0ef357a5cb Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 29 Jan 2025 12:00:13 -0800 Subject: [PATCH 7/8] PYTHON-5071 Go back to package scope --- test/asynchronous/conftest.py | 2 +- test/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/conftest.py b/test/asynchronous/conftest.py index e443dff6c0..a27a9f213d 100644 --- a/test/asynchronous/conftest.py +++ b/test/asynchronous/conftest.py @@ -22,7 +22,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture(scope="package", autouse=True) async def test_setup_and_teardown(): await async_setup() yield diff --git a/test/conftest.py b/test/conftest.py index a3d954c7c3..91fad28d0a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -20,7 +20,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="package", autouse=True) def test_setup_and_teardown(): setup() yield From b9fcba761f2cb0ba0db55a35587478873c5882c5 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Wed, 29 Jan 2025 14:29:18 -0800 Subject: [PATCH 8/8] PYTHON-5071 Cleanup --- test/__init__.py | 48 +++++++++++++++++------------------ test/asynchronous/__init__.py | 48 +++++++++++++++++------------------ 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index c32bbf4663..b49eee99ac 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -867,6 +867,28 @@ def max_message_size_bytes(self): # Reusable client context client_context = ClientContext() +# Global event loop for async tests. +LOOP = None + + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + class PyMongoTestCase(unittest.TestCase): if not _IS_SYNC: @@ -878,7 +900,6 @@ def setUp(self): def tearDown(self): pass - # See TestCase.addCleanup. def addCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) @@ -1214,31 +1235,10 @@ def tearDown(self) -> None: super().tearDown() -LOOP = None - - -def get_loop() -> asyncio.AbstractEventLoop: - global LOOP - if LOOP is None: - try: - LOOP = asyncio.get_running_loop() - except RuntimeError: - # no running event loop, fallback to get_event_loop. - try: - # Ignore DeprecationWarning: There is no current event loop - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - LOOP = asyncio.get_event_loop() - except RuntimeError: - LOOP = asyncio.new_event_loop() - asyncio.set_event_loop(LOOP) - return LOOP - - def setup(): if not _IS_SYNC: - global LOOP - LOOP = asyncio.get_running_loop() + # Set up the event loop. + get_loop() client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 3e0365229f..76fae407da 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -869,6 +869,28 @@ async def max_message_size_bytes(self): # Reusable client context async_client_context = AsyncClientContext() +# Global event loop for async tests. +LOOP = None + + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + class AsyncPyMongoTestCase(unittest.TestCase): if not _IS_SYNC: @@ -880,7 +902,6 @@ async def asyncSetUp(self): async def asyncTearDown(self): pass - # See IsolatedAsyncioTestCase.addAsyncCleanup. def addAsyncCleanup(self, func, /, *args, **kwargs): self.addCleanup(*(func, *args), **kwargs) @@ -1232,31 +1253,10 @@ async def asyncTearDown(self) -> None: await super().asyncTearDown() -LOOP = None - - -def get_loop() -> asyncio.AbstractEventLoop: - global LOOP - if LOOP is None: - try: - LOOP = asyncio.get_running_loop() - except RuntimeError: - # no running event loop, fallback to get_event_loop. - try: - # Ignore DeprecationWarning: There is no current event loop - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - LOOP = asyncio.get_event_loop() - except RuntimeError: - LOOP = asyncio.new_event_loop() - asyncio.set_event_loop(LOOP) - return LOOP - - async def async_setup(): if not _IS_SYNC: - global LOOP - LOOP = asyncio.get_running_loop() + # Set up the event loop. + get_loop() await async_client_context.init() warnings.resetwarnings() warnings.simplefilter("always")