diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index fbe310ad1e..16dad5eac4 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -31,7 +31,7 @@ set -o xtrace AUTH=${AUTH:-noauth} SSL=${SSL:-nossl} TEST_SUITES=${TEST_SUITES:-} -TEST_ARGS="${*:1}" +TEST_ARGS=("${*:1}") export PIP_QUIET=1 # Quiet by default export PIP_PREFER_BINARY=1 # Prefer binary dists by default @@ -206,6 +206,7 @@ if [ -n "$TEST_INDEX_MANAGEMENT" ]; then TEST_SUITES="index_management" fi +# shellcheck disable=SC2128 if [ -n "$TEST_DATA_LAKE" ] && [ -z "$TEST_ARGS" ]; then TEST_SUITES="data_lake" fi @@ -235,7 +236,7 @@ if [ -n "$PERF_TEST" ]; then TEST_SUITES="perf" # PYTHON-4769 Run perf_test.py directly otherwise pytest's test collection negatively # affects the benchmark results. - TEST_ARGS="test/performance/perf_test.py $TEST_ARGS" + TEST_ARGS+=("test/performance/perf_test.py") fi echo "Running $AUTH tests over $SSL with python $(uv python find)" @@ -251,7 +252,7 @@ if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then # Keep in sync with combine-coverage.sh. # coverage >=5 is needed for relative_files=true. UV_ARGS+=("--group coverage") - TEST_ARGS="$TEST_ARGS --cov" + TEST_ARGS+=("--cov") fi if [ -n "$GREEN_FRAMEWORK" ]; then @@ -265,15 +266,37 @@ PIP_QUIET=0 uv run ${UV_ARGS[*]} --with pip pip list if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html - PYTEST_ARGS="-v --capture=tee-sys --durations=5 $TEST_ARGS" + PYTEST_ARGS=("-v" "--capture=tee-sys" "--durations=5" "${TEST_ARGS[@]}") if [ -n "$TEST_SUITES" ]; then - PYTEST_ARGS="-m $TEST_SUITES $PYTEST_ARGS" + # Workaround until unittest -> pytest conversion is complete + if [[ "$TEST_SUITES" == *"default_async"* ]]; then + ASYNC_PYTEST_ARGS=("-m asyncio" "--junitxml=xunit-results/TEST-asyncresults.xml" "${PYTEST_ARGS[@]}") + else + ASYNC_PYTEST_ARGS=("-m asyncio and $TEST_SUITES" "--junitxml=xunit-results/TEST-asyncresults.xml" "${PYTEST_ARGS[@]}") + fi + PYTEST_ARGS=("-m $TEST_SUITES and not asyncio" "${PYTEST_ARGS[@]}") + else + ASYNC_PYTEST_ARGS=("-m asyncio" "--junitxml=xunit-results/TEST-asyncresults.xml" "${PYTEST_ARGS[@]}") fi + # Workaround until unittest -> pytest conversion is complete + set +o errexit + # shellcheck disable=SC2048 + uv run ${UV_ARGS[*]} pytest "${PYTEST_ARGS[@]}" + exit_code=$? + # shellcheck disable=SC2048 - uv run ${UV_ARGS[*]} pytest $PYTEST_ARGS + uv run ${UV_ARGS[*]} pytest "${ASYNC_PYTEST_ARGS[@]}" + async_exit_code=$? + set -o errexit + if [ $async_exit_code -ne 5 ] && [ $async_exit_code -ne 0 ]; then + exit $async_exit_code + fi + if [ $exit_code -ne 0 ]; then + exit $exit_code + fi else # shellcheck disable=SC2048 - uv run ${UV_ARGS[*]} green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS + uv run ${UV_ARGS[*]} green_framework_test.py $GREEN_FRAMEWORK -v "${TEST_ARGS[@]}" fi # Handle perf test post actions. diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 3760e308a5..b3c5902824 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -94,8 +94,10 @@ jobs: run: | if [[ "${{ matrix.python-version }}" == "3.13t" ]]; then pytest -v --durations=5 --maxfail=10 + pytest -v --durations=5 --maxfail=10 -m asyncio else just test + just test-async fi doctest: diff --git a/justfile b/justfile index 8a076038a4..07148c4046 100644 --- a/justfile +++ b/justfile @@ -62,6 +62,10 @@ lint-manual: test *args="-v --durations=5 --maxfail=10": {{uv_run}} --extra test pytest {{args}} +[group('test')] +test-async *args="-v --durations=5 --maxfail=10 -m asyncio": + {{uv_run}} --extra test pytest {{args}} + [group('test')] test-mockupdb *args: {{uv_run}} -v --extra test --group mockupdb pytest -m mockupdb {{args}} diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 1600e50628..2f549c6f3c 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1418,6 +1418,9 @@ def __next__(self) -> NoReturn: raise TypeError("'AsyncMongoClient' object is not iterable") next = __next__ + if not _IS_SYNC: + anext = next + __anext__ = next async def _server_property(self, attr_name: str) -> Any: """An attribute of the current server's description. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index a694a58c1e..a199e0ea2d 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1414,6 +1414,9 @@ def __next__(self) -> NoReturn: raise TypeError("'MongoClient' object is not iterable") next = __next__ + if not _IS_SYNC: + next = next + __next__ = next def _server_property(self, attr_name: str) -> Any: """An attribute of the current server's description. diff --git a/pyproject.toml b/pyproject.toml index 69249ee4c6..e3540f1bc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ zstd = ["requirements/zstd.txt"] [tool.pytest.ini_options] minversion = "7" -addopts = ["-ra", "--strict-config", "--strict-markers", "--junitxml=xunit-results/TEST-results.xml", "-m default or default_async"] +addopts = ["-ra", "--strict-config", "--strict-markers", "--junitxml=xunit-results/TEST-results.xml", "-m default or default_async and not asyncio"] testpaths = ["test"] log_cli_level = "INFO" faulthandler_timeout = 1500 @@ -135,6 +135,8 @@ markers = [ "mockupdb: tests that rely on mockupdb", "default: default test suite", "default_async: default async test suite", + "unit: tests that don't require a connection to MongoDB", + "integration: tests that require a connection to MongoDB", ] [tool.mypy] diff --git a/test/__init__.py b/test/__init__.py index d3a63db2d5..7f8b2b5cc8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -15,9 +15,7 @@ """Synchronous test suite for pymongo, bson, and gridfs.""" from __future__ import annotations -import asyncio import gc -import logging import multiprocessing import os import signal @@ -26,7 +24,6 @@ import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -518,6 +515,12 @@ def require_data_lake(self, func): func=func, ) + @property + def is_not_mmap(self): + if self.is_mongos: + return True + return self.storage_engine != "mmapv1" + def require_no_mmap(self, func): """Run a test only if the server is not using the MMAPv1 storage engine. Only works for standalone and replica sets; tests are @@ -571,6 +574,10 @@ def require_replica_set(self, func): """Run a test only if the client is connected to a replica set.""" return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func) + @property + def secondaries_count(self): + return 0 if not self.client else len(self.client.secondaries) + def require_secondaries_count(self, count): """Run a test only if the client is connected to a replica set that has `count` secondaries. @@ -589,7 +596,7 @@ def supports_secondary_read_pref(self): if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False @@ -874,6 +881,18 @@ def reset_client_context(): client_context._init_client() +class PyMongoTestCasePyTest: + @contextmanager + def fail_point(self, client, command_args): + cmd_on = SON([("configureFailPoint", "failCommand")]) + cmd_on.update(command_args) + client.admin.command(cmd_on) + try: + yield + finally: + client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + + class PyMongoTestCase(unittest.TestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 73e2824742..631100d766 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -15,9 +15,7 @@ """Asynchronous test suite for pymongo, bson, and gridfs.""" from __future__ import annotations -import asyncio import gc -import logging import multiprocessing import os import signal @@ -26,7 +24,6 @@ import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -520,6 +517,12 @@ def require_data_lake(self, func): func=func, ) + @property + def is_not_mmap(self): + if self.is_mongos: + return True + return self.storage_engine != "mmapv1" + def require_no_mmap(self, func): """Run a test only if the server is not using the MMAPv1 storage engine. Only works for standalone and replica sets; tests are @@ -573,6 +576,10 @@ def require_replica_set(self, func): """Run a test only if the client is connected to a replica set.""" return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func) + @property + async def secondaries_count(self): + return 0 if not self.client else len(await self.client.secondaries) + def require_secondaries_count(self, count): """Run a test only if the client is connected to a replica set that has `count` secondaries. @@ -588,10 +595,10 @@ async def check(): @property async def supports_secondary_read_pref(self): - if self.has_secondaries: + if await self.has_secondaries: return True if self.is_mongos: - shard = await self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (await self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False @@ -876,6 +883,20 @@ async def reset_client_context(): await async_client_context._init_client() +class AsyncPyMongoTestCasePyTest: + @asynccontextmanager + async def fail_point(self, client, command_args): + cmd_on = SON([("configureFailPoint", "failCommand")]) + cmd_on.update(command_args) + await client.admin.command(cmd_on) + try: + yield + finally: + await client.admin.command( + "configureFailPoint", cmd_on["configureFailPoint"], mode="off" + ) + + class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) diff --git a/test/asynchronous/conftest.py b/test/asynchronous/conftest.py index a27a9f213d..edb25604c5 100644 --- a/test/asynchronous/conftest.py +++ b/test/asynchronous/conftest.py @@ -2,12 +2,24 @@ import asyncio import sys -from test import pytest_conf -from test.asynchronous import async_setup, async_teardown +from test import MONGODB_API_VERSION, db_pwd, db_user, pytest_conf +from test.asynchronous import ( + AsyncClientContext, + _connection_string, + async_setup, + async_teardown, +) +from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.utils import FunctionCallRecorder +from typing import Any, Callable import pytest import pytest_asyncio +import pymongo +from pymongo import AsyncMongoClient +from pymongo.uri_parser import parse_uri + _IS_SYNC = False @@ -22,11 +34,351 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest_asyncio.fixture(scope="package", autouse=True) +@pytest_asyncio.fixture(loop_scope="session", scope="session") +async def async_client_context_fixture(): + client = AsyncClientContext() + await client.init() + yield client + if client.client is not None: + if not client.is_data_lake: + await client.client.drop_database("pymongo-pooling-tests") + await client.client.drop_database("pymongo_test") + await client.client.drop_database("pymongo_test1") + await client.client.drop_database("pymongo_test2") + await client.client.drop_database("pymongo_test_mike") + await client.client.drop_database("pymongo_test_bernie") + await client.client.close() + + +@pytest_asyncio.fixture +async def require_integration(async_client_context_fixture): + if not async_client_context_fixture.connected: + pytest.fail("Integration tests require a MongoDB server") + + +@pytest_asyncio.fixture(loop_scope="session", scope="session") +async def test_environment(async_client_context_fixture): + requirements = {} + requirements["SUPPORT_TRANSACTIONS"] = async_client_context_fixture.supports_transactions() + requirements["IS_DATA_LAKE"] = async_client_context_fixture.is_data_lake + requirements["IS_SYNC"] = _IS_SYNC + requirements["IS_SYNC"] = _IS_SYNC + requirements["REQUIRE_API_VERSION"] = MONGODB_API_VERSION + requirements[ + "SUPPORTS_FAILCOMMAND_FAIL_POINT" + ] = async_client_context_fixture.supports_failCommand_fail_point + requirements["IS_NOT_MMAP"] = async_client_context_fixture.is_not_mmap + requirements["SERVER_VERSION"] = async_client_context_fixture.version + requirements["AUTH_ENABLED"] = async_client_context_fixture.auth_enabled + requirements["FIPS_ENABLED"] = async_client_context_fixture.fips_enabled + requirements["IS_RS"] = async_client_context_fixture.is_rs + requirements["MONGOSES"] = len(async_client_context_fixture.mongoses) + requirements["SECONDARIES_COUNT"] = await async_client_context_fixture.secondaries_count + requirements[ + "SECONDARY_READ_PREF" + ] = await async_client_context_fixture.supports_secondary_read_pref + requirements["HAS_IPV6"] = async_client_context_fixture.has_ipv6 + requirements["IS_SERVERLESS"] = async_client_context_fixture.serverless + requirements["IS_LOAD_BALANCER"] = async_client_context_fixture.load_balancer + requirements["TEST_COMMANDS_ENABLED"] = async_client_context_fixture.test_commands_enabled + requirements["IS_TLS"] = async_client_context_fixture.tls + requirements["IS_TLS_CERT"] = async_client_context_fixture.tlsCertificateKeyFile + requirements["SERVER_IS_RESOLVEABLE"] = async_client_context_fixture.server_is_resolvable + requirements["SESSIONS_ENABLED"] = async_client_context_fixture.sessions_enabled + requirements[ + "SUPPORTS_RETRYABLE_WRITES" + ] = async_client_context_fixture.supports_retryable_writes() + yield requirements + + +@pytest_asyncio.fixture +async def require_auth(test_environment): + if not test_environment["AUTH_ENABLED"]: + pytest.skip("Authentication is not enabled on the server") + + +@pytest_asyncio.fixture +async def require_no_fips(test_environment): + if test_environment["FIPS_ENABLED"]: + pytest.skip("Test cannot run on a FIPS-enabled host") + + +@pytest_asyncio.fixture +async def require_no_tls(test_environment): + if test_environment["IS_TLS"]: + pytest.skip("Must be able to connect without TLS") + + +@pytest_asyncio.fixture +async def require_ipv6(test_environment): + if not test_environment["HAS_IPV6"]: + pytest.skip("No IPv6") + + +@pytest_asyncio.fixture +async def require_sync(test_environment): + if not _IS_SYNC: + pytest.skip("This test only works with the synchronous API") + + +@pytest_asyncio.fixture +async def require_no_mongos(test_environment): + if test_environment["MONGOSES"]: + pytest.skip("Must be connected to a mongod, not a mongos") + + +@pytest_asyncio.fixture +async def require_no_replica_set(test_environment): + if test_environment["IS_RS"]: + pytest.skip("Connected to a replica set, not a standalone mongod") + + +@pytest_asyncio.fixture +async def require_replica_set(test_environment): + if not test_environment["IS_RS"]: + pytest.skip("Not connected to a replica set") + + +@pytest_asyncio.fixture +async def require_sdam(test_environment): + if test_environment["IS_SERVERLESS"] or test_environment["IS_LOAD_BALANCER"]: + pytest.skip("loadBalanced and serverless clients do not run SDAM") + + +@pytest_asyncio.fixture +async def require_no_load_balancer(test_environment): + if test_environment["IS_LOAD_BALANCER"]: + pytest.skip("Must not be connected to a load balancer") + + +@pytest_asyncio.fixture +async def require_failCommand_fail_point(test_environment): + if not test_environment["SUPPORTS_FAILCOMMAND_FAIL_POINT"]: + pytest.skip("failCommand fail point must be supported") + + +@pytest_asyncio.fixture(loop_scope="session", scope="session", autouse=True) async def test_setup_and_teardown(): await async_setup() yield await async_teardown() +async def _async_mongo_client( + async_client_context_fixture, host, port, authenticate=True, directConnection=None, **kwargs +): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context_fixture.host + port = port or await async_client_context_fixture.port + client_options: dict = async_client_context_fixture.default_client_options.copy() + if async_client_context_fixture.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context_fixture.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context_fixture.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + return client + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_single_client_noauth( + async_client_context_fixture +) -> Callable[..., AsyncMongoClient]: + """Make a direct connection. Don't authenticate.""" + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = await _async_mongo_client( + async_client_context_fixture, h, p, authenticate=False, directConnection=True, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_single_client(async_client_context_fixture) -> Callable[..., AsyncMongoClient]: + """Make a direct connection, and authenticate if necessary.""" + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = await _async_mongo_client( + async_client_context_fixture, h, p, directConnection=True, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_rs_client_noauth(async_client_context_fixture) -> Callable[..., AsyncMongoClient]: + """Connect to the replica set. Don't authenticate.""" + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = await _async_mongo_client( + async_client_context_fixture, h, p, authenticate=False, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_rs_client(async_client_context_fixture) -> Callable[..., AsyncMongoClient]: + """Connect to the replica set and authenticate if necessary.""" + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = await _async_mongo_client(async_client_context_fixture, h, p, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_rs_or_single_client_noauth( + async_client_context_fixture +) -> Callable[..., AsyncMongoClient]: + """Connect to the replica set if there is one, otherwise the standalone. + + Like rs_or_single_client, but does not authenticate. + """ + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = await _async_mongo_client( + async_client_context_fixture, h, p, authenticate=False, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_rs_or_single_client( + async_client_context_fixture +) -> Callable[..., AsyncMongoClient]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = await _async_mongo_client(async_client_context_fixture, h, p, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def simple_client() -> Callable[..., AsyncMongoClient]: + clients = [] + + async def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest.fixture(scope="function") +def patch_resolver(): + from pymongo.srv_resolver import _resolve + + patched_resolver = FunctionCallRecorder(_resolve) + pymongo.srv_resolver._resolve = patched_resolver + yield patched_resolver + pymongo.srv_resolver._resolve = _resolve + + +@pytest_asyncio.fixture(loop_scope="session") +async def async_mock_client(): + clients = [] + + async def _make_client( + standalones, + members, + mongoses, + hello_hosts=None, + arbiters=None, + down_hosts=None, + *args, + **kwargs, + ): + client = await AsyncMockClient.get_async_mock_client( + standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + await client.close() + + +@pytest_asyncio.fixture(loop_scope="session") +async def remove_all_users_fixture(async_client_context_fixture, request): + db_name = request.param + yield + await async_client_context_fixture.client[db_name].command( + "dropAllUsersFromDatabase", 1, writeConcern={"w": async_client_context_fixture.w} + ) + + +@pytest_asyncio.fixture(loop_scope="session") +async def drop_user_fixture(async_client_context_fixture, request): + db, user = request.param + yield + await async_client_context_fixture.drop_user(db, user) + + +@pytest_asyncio.fixture(loop_scope="session") +async def drop_database_fixture(async_client_context_fixture, request): + db = request.param + yield + await async_client_context_fixture.client.drop_database(db) + + pytest_collection_modifyitems = pytest_conf.pytest_collection_modifyitems diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc98..6f5bedeff6 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -166,7 +166,8 @@ async def get_async_mock_client( standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs ) - await c.aconnect() + if kwargs.get("connect", True): + await c.aconnect() return c def kill_host(self, host): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 744a170be2..1249db7d20 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -17,7 +17,6 @@ import _thread as thread import asyncio -import base64 import contextlib import copy import datetime @@ -47,24 +46,18 @@ from test.asynchronous import ( HAVE_IPADDRESS, - AsyncIntegrationTest, - AsyncMockClientTest, - AsyncUnitTest, + AsyncPyMongoTestCasePyTest, SkipTest, - async_client_context, client_knobs, connected, db_pwd, db_user, - remove_all_users, - unittest, ) -from test.asynchronous.pymongo_mocks import AsyncMockClient from test.test_binary import BinaryData from test.utils import ( NTHREADS, CMAPListener, - FunctionCallRecorder, + _default_pytest_mark, async_get_pool, async_wait_until, asyncAssertRaisesExactly, @@ -125,22 +118,19 @@ _IS_SYNC = False -class AsyncClientUnitTest(AsyncUnitTest): - """AsyncMongoClient tests that don't require a server.""" +pytestmark = _default_pytest_mark(_IS_SYNC) - client: AsyncMongoClient - async def asyncSetUp(self) -> None: - self.client = await self.async_rs_or_single_client( - connect=False, serverSelectionTimeoutMS=100 - ) - - @pytest.fixture(autouse=True) - def inject_fixtures(self, caplog): - self._caplog = caplog +@pytest.mark.unit +class TestClientUnitTest: + @pytest_asyncio.fixture(loop_scope="session") + async def async_client(self, async_rs_or_single_client) -> AsyncMongoClient: + client = await async_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + yield client + await client.close() - async def test_keyword_arg_defaults(self): - client = self.simple_client( + async def test_keyword_arg_defaults(self, simple_client): + client = await simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -156,220 +146,266 @@ async def test_keyword_arg_defaults(self): options = client.options pool_opts = options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) + assert pool_opts.socket_timeout is None # socket.Socket.settimeout takes a float in seconds - self.assertEqual(20.0, pool_opts.connect_timeout) - self.assertEqual(None, pool_opts.wait_queue_timeout) - self.assertEqual(None, pool_opts._ssl_context) - self.assertEqual(None, options.replica_set_name) - self.assertEqual(ReadPreference.PRIMARY, client.read_preference) - self.assertAlmostEqual(12, client.options.server_selection_timeout) - - async def test_connect_timeout(self): - client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + assert 20.0 == pool_opts.connect_timeout + assert pool_opts.wait_queue_timeout is None + assert pool_opts._ssl_context is None + assert options.replica_set_name is None + assert client.read_preference == ReadPreference.PRIMARY + assert pytest.approx(client.options.server_selection_timeout, rel=1e-9) == 12 + + async def test_connect_timeout(self, simple_client): + client = await simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - self.assertEqual(None, pool_opts.connect_timeout) + assert pool_opts.socket_timeout is None + assert pool_opts.connect_timeout is None - client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + client = await simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - self.assertEqual(None, pool_opts.connect_timeout) + assert pool_opts.socket_timeout is None + assert pool_opts.connect_timeout is None - client = self.simple_client( + client = await simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - self.assertEqual(None, pool_opts.connect_timeout) - - def test_types(self): - self.assertRaises(TypeError, AsyncMongoClient, 1) - self.assertRaises(TypeError, AsyncMongoClient, 1.14) - self.assertRaises(TypeError, AsyncMongoClient, "localhost", "27017") - self.assertRaises(TypeError, AsyncMongoClient, "localhost", 1.14) - self.assertRaises(TypeError, AsyncMongoClient, "localhost", []) - - self.assertRaises(ConfigurationError, AsyncMongoClient, []) - - async def test_max_pool_size_zero(self): - self.simple_client(maxPoolSize=0) - - def test_uri_detection(self): - self.assertRaises(ConfigurationError, AsyncMongoClient, "/foo") - self.assertRaises(ConfigurationError, AsyncMongoClient, "://") - self.assertRaises(ConfigurationError, AsyncMongoClient, "foo/") - - def test_get_db(self): + assert pool_opts.socket_timeout is None + assert pool_opts.connect_timeout is None + + async def test_types(self): + with pytest.raises(TypeError): + AsyncMongoClient(1) # type: ignore[arg-type] + with pytest.raises(TypeError): + AsyncMongoClient(1.14) # type: ignore[arg-type] + with pytest.raises(TypeError): + AsyncMongoClient("localhost", "27017") # type: ignore[arg-type] + with pytest.raises(TypeError): + AsyncMongoClient("localhost", 1.14) # type: ignore[arg-type] + with pytest.raises(TypeError): + AsyncMongoClient("localhost", []) # type: ignore[arg-type] + + with pytest.raises(ConfigurationError): + AsyncMongoClient([]) + + async def test_max_pool_size_zero(self, simple_client): + await simple_client(maxPoolSize=0) + + async def test_uri_detection(self): + with pytest.raises(ConfigurationError): + AsyncMongoClient("/foo") + with pytest.raises(ConfigurationError): + AsyncMongoClient("://") + with pytest.raises(ConfigurationError): + AsyncMongoClient("foo/") + + async def test_get_db(self, async_client): def make_db(base, name): return base[name] - self.assertRaises(InvalidName, make_db, self.client, "") - self.assertRaises(InvalidName, make_db, self.client, "te$t") - self.assertRaises(InvalidName, make_db, self.client, "te.t") - self.assertRaises(InvalidName, make_db, self.client, "te\\t") - self.assertRaises(InvalidName, make_db, self.client, "te/t") - self.assertRaises(InvalidName, make_db, self.client, "te st") - - self.assertTrue(isinstance(self.client.test, AsyncDatabase)) - self.assertEqual(self.client.test, self.client["test"]) - self.assertEqual(self.client.test, AsyncDatabase(self.client, "test")) - - def test_get_database(self): + with pytest.raises(InvalidName): + make_db(async_client, "") + with pytest.raises(InvalidName): + make_db(async_client, "te$t") + with pytest.raises(InvalidName): + make_db(async_client, "te.t") + with pytest.raises(InvalidName): + make_db(async_client, "te\\t") + with pytest.raises(InvalidName): + make_db(async_client, "te/t") + with pytest.raises(InvalidName): + make_db(async_client, "te st") + # Type and equality assertions + assert isinstance(async_client.test, AsyncDatabase) + assert async_client.test == async_client["test"] + assert async_client.test == AsyncDatabase(async_client, "test") + + async def test_get_database(self, async_client): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) - db = self.client.get_database("foo", codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + db = async_client.get_database( + "foo", codec_options, ReadPreference.SECONDARY, write_concern + ) + assert db.name == "foo" + assert db.codec_options == codec_options + assert db.read_preference == ReadPreference.SECONDARY + assert db.write_concern == write_concern - def test_getattr(self): - self.assertTrue(isinstance(self.client["_does_not_exist"], AsyncDatabase)) + async def test_getattr(self, async_client): + assert isinstance(async_client["_does_not_exist"], AsyncDatabase) - with self.assertRaises(AttributeError) as context: - self.client._does_not_exist + with pytest.raises(AttributeError) as context: + async_client.client._does_not_exist # Message should be: # "AttributeError: AsyncMongoClient has no attribute '_does_not_exist'. To # access the _does_not_exist database, use client['_does_not_exist']". - self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) + assert "has no attribute '_does_not_exist'" in str(context.value) - def test_iteration(self): - client = self.client + async def test_iteration(self, async_client): msg = "'AsyncMongoClient' object is not iterable" - # Iteration fails - with self.assertRaisesRegex(TypeError, msg): - for _ in client: # type: ignore[misc] # error: "None" not callable [misc] + + with pytest.raises(TypeError, match=msg): + for _ in async_client: break + # Index fails - with self.assertRaises(TypeError): - _ = client[0] - # next fails - with self.assertRaisesRegex(TypeError, "'AsyncMongoClient' object is not iterable"): - _ = next(client) - # .next() fails - with self.assertRaisesRegex(TypeError, "'AsyncMongoClient' object is not iterable"): - _ = client.next() - # Do not implement typing.Iterable. - self.assertNotIsInstance(client, Iterable) - - async def test_get_default_database(self): - c = await self.async_rs_or_single_client( + with pytest.raises(TypeError): + _ = async_client[0] + + # 'next' function fails + with pytest.raises(TypeError, match=msg): + _ = await anext(async_client) + + # 'next()' method fails + with pytest.raises(TypeError, match=msg): + _ = await async_client.anext() + + # Do not implement typing.Iterable + assert not isinstance(async_client, Iterable) + + async def test_get_default_database( + self, async_rs_or_single_client, async_client_context_fixture + ): + c = await async_rs_or_single_client( "mongodb://%s:%d/foo" - % (await async_client_context.host, await async_client_context.port), + % (await async_client_context_fixture.host, await async_client_context_fixture.port), connect=False, ) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + assert AsyncDatabase(c, "foo") == c.get_default_database() # Test that default doesn't override the URI value. - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("bar")) - + assert AsyncDatabase(c, "foo") == c.get_default_database("bar") codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) - - c = await self.async_rs_or_single_client( - "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), + assert "foo" == db.name + assert codec_options == db.codec_options + assert ReadPreference.SECONDARY == db.read_preference + assert write_concern == db.write_concern + + c = await async_rs_or_single_client( + "mongodb://%s:%d/" + % (await async_client_context_fixture.host, await async_client_context_fixture.port), connect=False, ) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) + assert AsyncDatabase(c, "foo") == c.get_default_database("foo") - async def test_get_default_database_error(self): + async def test_get_default_database_error( + self, async_rs_or_single_client, async_client_context_fixture + ): # URI with no database. - c = await self.async_rs_or_single_client( - "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), + c = await async_rs_or_single_client( + "mongodb://%s:%d/" + % (await async_client_context_fixture.host, await async_client_context_fixture.port), connect=False, ) - self.assertRaises(ConfigurationError, c.get_default_database) + with pytest.raises(ConfigurationError): + c.get_default_database() - async def test_get_default_database_with_authsource(self): + async def test_get_default_database_with_authsource( + self, async_client_context_fixture, async_rs_or_single_client + ): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( - await async_client_context.host, - await async_client_context.port, + await async_client_context_fixture.host, + await async_client_context_fixture.port, ) - c = await self.async_rs_or_single_client(uri, connect=False) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + c = await async_rs_or_single_client(uri, connect=False) + assert AsyncDatabase(c, "foo") == c.get_default_database() - async def test_get_database_default(self): - c = await self.async_rs_or_single_client( + async def test_get_database_default( + self, async_client_context_fixture, async_rs_or_single_client + ): + c = await async_rs_or_single_client( "mongodb://%s:%d/foo" - % (await async_client_context.host, await async_client_context.port), + % (await async_client_context_fixture.host, await async_client_context_fixture.port), connect=False, ) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + assert AsyncDatabase(c, "foo") == c.get_database() - async def test_get_database_default_error(self): + async def test_get_database_default_error( + self, async_client_context_fixture, async_rs_or_single_client + ): # URI with no database. - c = await self.async_rs_or_single_client( - "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), + c = await async_rs_or_single_client( + "mongodb://%s:%d/" + % (await async_client_context_fixture.host, await async_client_context_fixture.port), connect=False, ) - self.assertRaises(ConfigurationError, c.get_database) + with pytest.raises(ConfigurationError): + c.get_database() - async def test_get_database_default_with_authsource(self): + async def test_get_database_default_with_authsource( + self, async_client_context_fixture, async_rs_or_single_client + ): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( - await async_client_context.host, - await async_client_context.port, + await async_client_context_fixture.host, + await async_client_context_fixture.port, ) - c = await self.async_rs_or_single_client(uri, connect=False) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + c = await async_rs_or_single_client(uri, connect=False) + assert AsyncDatabase(c, "foo") == c.get_database() - async def test_primary_read_pref_with_tags(self): + async def test_primary_read_pref_with_tags(self, async_single_client): # No tags allowed with "primary". - with self.assertRaises(ConfigurationError): - await self.async_single_client("mongodb://host/?readpreferencetags=dc:east") - - with self.assertRaises(ConfigurationError): - await self.async_single_client( + with pytest.raises(ConfigurationError): + async with await async_single_client("mongodb://host/?readpreferencetags=dc:east"): + pass + with pytest.raises(ConfigurationError): + async with await async_single_client( "mongodb://host/?readpreference=primary&readpreferencetags=dc:east" - ) + ): + pass - async def test_read_preference(self): - c = await self.async_rs_or_single_client( + async def test_read_preference(self, async_client_context_fixture, async_rs_or_single_client): + c = await async_rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) - self.assertEqual(c.read_preference, ReadPreference.NEAREST) + assert c.read_preference == ReadPreference.NEAREST - async def test_metadata(self): + async def test_metadata(self, simple_client): metadata = copy.deepcopy(_METADATA) if has_c(): metadata["driver"]["name"] = "PyMongo|c|async" else: metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") + + client = await simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options - self.assertEqual(options.pool_options.metadata, metadata) - client = self.simple_client("foo", 27017, appname="foobar", connect=False) + assert options.pool_options.metadata == metadata + + client = await simple_client("foo", 27017, appname="foobar", connect=False) options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata + # No error - self.simple_client(appname="x" * 128) - with self.assertRaises(ValueError): - self.simple_client(appname="x" * 129) - # Bad "driver" options. - self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") - self.assertRaises(TypeError, DriverInfo, version="1", platform="a") - self.assertRaises(TypeError, DriverInfo) - with self.assertRaises(TypeError): - self.simple_client(driver=1) - with self.assertRaises(TypeError): - self.simple_client(driver="abc") - with self.assertRaises(TypeError): - self.simple_client(driver=("Foo", "1", "a")) - # Test appending to driver info. + await simple_client(appname="x" * 128) + with pytest.raises(ValueError): + await simple_client(appname="x" * 129) + + # Bad "driver" options. + with pytest.raises(TypeError): + DriverInfo("Foo", 1, "a") # type: ignore[arg-type] + with pytest.raises(TypeError): + DriverInfo(version="1", platform="a") # type: ignore[call-arg] + with pytest.raises(TypeError): + DriverInfo() # type: ignore[call-arg] + with pytest.raises(TypeError): + await simple_client(driver=1) + with pytest.raises(TypeError): + await simple_client(driver="abc") + with pytest.raises(TypeError): + await simple_client(driver=("Foo", "1", "a")) + + # Test appending to driver info. if has_c(): metadata["driver"]["name"] = "PyMongo|c|async|FooDriver" else: metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = self.simple_client( + + client = await simple_client( "foo", 27017, appname="foobar", @@ -377,9 +413,10 @@ async def test_metadata(self): connect=False, ) options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata + metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = self.simple_client( + client = await simple_client( "foo", 27017, appname="foobar", @@ -387,38 +424,35 @@ async def test_metadata(self): connect=False, ) options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata + # Test truncating driver info metadata. - client = self.simple_client( + client = await simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) - client = self.simple_client( + assert len(bson.encode(options.pool_options.metadata)) <= _MAX_METADATA_SIZE + + client = await simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) + assert len(bson.encode(options.pool_options.metadata)) <= _MAX_METADATA_SIZE - @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) - def test_container_metadata(self): - metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo|async" - metadata["env"] = {} - metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") - options = client.options - self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) + async def test_container_metadata(self, simple_client): + with mock.patch("os.environ", {ENV_VAR_K8S: "1"}): + metadata = copy.deepcopy(_METADATA) + metadata["driver"]["name"] = "PyMongo|async" + metadata["env"] = {} + metadata["env"]["container"] = {"orchestrator": "kubernetes"} - async def test_kwargs_codec_options(self): + client = await simple_client("mongodb://foo:27017/?appname=foobar&connect=false") + options = client.options + assert options.pool_options.metadata["env"] == metadata["env"] + + async def test_kwargs_codec_options(self, simple_client): class MyFloatType: def __init__(self, x): self.__x = x @@ -440,7 +474,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = self.simple_client( + c = await simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -449,18 +483,16 @@ def transform_python(self, value): tzinfo=tzinfo, connect=False, ) - self.assertEqual(c.codec_options.document_class, document_class) - self.assertEqual(c.codec_options.type_registry, type_registry) - self.assertEqual(c.codec_options.tz_aware, tz_aware) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], + assert c.codec_options.document_class == document_class + assert c.codec_options.type_registry == type_registry + assert c.codec_options.tz_aware == tz_aware + assert ( + c.codec_options.uuid_representation == _UUID_REPRESENTATIONS[uuid_representation_label] ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual(c.codec_options.tzinfo, tzinfo) + assert c.codec_options.unicode_decode_error_handler == unicode_decode_error_handler + assert c.codec_options.tzinfo == tzinfo - async def test_uri_codec_options(self): - # Ensure codec options are passed in correctly + async def test_uri_codec_options(self, async_client_context_fixture, simple_client): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" datetime_conversion = "DATETIME_CLAMP" @@ -469,57 +501,40 @@ async def test_uri_codec_options(self): "%s&unicode_decode_error_handler=%s" "&datetime_conversion=%s" % ( - await async_client_context.host, - await async_client_context.port, + await async_client_context_fixture.host, + await async_client_context_fixture.port, uuid_representation_label, unicode_decode_error_handler, datetime_conversion, ) ) - c = self.simple_client(uri, connect=False) - self.assertEqual(c.codec_options.tz_aware, True) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], + c = await simple_client(uri, connect=False) + assert c.codec_options.tz_aware is True + assert ( + c.codec_options.uuid_representation == _UUID_REPRESENTATIONS[uuid_representation_label] ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) - + assert c.codec_options.unicode_decode_error_handler == unicode_decode_error_handler + assert c.codec_options.datetime_conversion == DatetimeConversion[datetime_conversion] # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = self.simple_client(uri, connect=False) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + c = await simple_client(uri, connect=False) + assert c.codec_options.datetime_conversion == DatetimeConversion[datetime_conversion] - async def test_uri_option_precedence(self): + async def test_uri_option_precedence(self, simple_client): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = self.simple_client( + c = await simple_client( uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" ) clopts = c.options opts = clopts._options + assert opts["tls"] is False + assert clopts.replica_set_name == "newname" + assert clopts.read_preference == ReadPreference.SECONDARY_PREFERRED - self.assertEqual(opts["tls"], False) - self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) - - async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): - # Patch the resolver. - from pymongo.srv_resolver import _resolve - - patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver - - def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve - - self.addCleanup(reset_resolver) - - # Setup. + async def test_connection_timeout_ms_propagates_to_DNS_resolver( + self, patch_resolver, simple_client + ): base_uri = "mongodb+srv://test5.test.build.10gen.cc" connectTimeoutMS = 5000 expected_kw_value = 5.0 @@ -527,10 +542,10 @@ def reset_resolver(): expected_uri_value = 6.0 async def test_scenario(args, kwargs, expected_value): - patched_resolver.reset() - self.simple_client(*args, **kwargs) - for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw["lifetime"], expected_value) + patch_resolver.reset() + await simple_client(*args, **kwargs) + for _, kw in patch_resolver.call_list(): + assert pytest.approx(kw["lifetime"], rel=1e-6) == expected_value # No timeout specified. await test_scenario((base_uri,), {}, CONNECT_TIMEOUT) @@ -545,38 +560,38 @@ async def test_scenario(args, kwargs, expected_value): # Timeout specified in both kwargs and connection string. await test_scenario((uri_with_timeout,), kwarg, expected_kw_value) - async def test_uri_security_options(self): + async def test_uri_security_options(self, simple_client): # Ensure that we don't silently override security-related options. - with self.assertRaises(InvalidURI): - self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) + with pytest.raises(InvalidURI): + await simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) - self.assertEqual(c.options._options["tls"], False) + c = await simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) + assert c.options._options["tls"] is False # Conflicting tlsInsecure options should raise an error. - with self.assertRaises(InvalidURI): - self.simple_client( + with pytest.raises(InvalidURI): + await simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, ) # Conflicting legacy tlsInsecure options should also raise an error. - with self.assertRaises(InvalidURI): - self.simple_client( + with pytest.raises(InvalidURI): + await simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, ) # Conflicting kwargs should raise InvalidURI - with self.assertRaises(InvalidURI): - self.simple_client(ssl=True, tls=False) + with pytest.raises(InvalidURI): + await simple_client(ssl=True, tls=False) - async def test_event_listeners(self): - c = self.simple_client(event_listeners=[], connect=False) - self.assertEqual(c.options.event_listeners, []) + async def test_event_listeners(self, simple_client): + c = await simple_client(event_listeners=[], connect=False) + assert c.options.event_listeners == [] listeners = [ event_loggers.CommandLogger(), event_loggers.HeartbeatLogger(), @@ -584,28 +599,30 @@ async def test_event_listeners(self): event_loggers.TopologyLogger(), event_loggers.ConnectionPoolLogger(), ] - c = self.simple_client(event_listeners=listeners, connect=False) - self.assertEqual(c.options.event_listeners, listeners) - - async def test_client_options(self): - c = self.simple_client(connect=False) - self.assertIsInstance(c.options, ClientOptions) - self.assertIsInstance(c.options.pool_options, PoolOptions) - self.assertEqual(c.options.server_selection_timeout, 30) - self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) - self.assertIsInstance(c.options.retry_writes, bool) - self.assertIsInstance(c.options.retry_reads, bool) - - def test_validate_suggestion(self): + c = await simple_client(event_listeners=listeners, connect=False) + assert c.options.event_listeners == listeners + + async def test_client_options(self, simple_client): + c = await simple_client(connect=False) + assert isinstance(c.options, ClientOptions) + assert isinstance(c.options.pool_options, PoolOptions) + assert c.options.server_selection_timeout == 30 + assert c.options.pool_options.max_idle_time_seconds is None + assert isinstance(c.options.retry_writes, bool) + assert isinstance(c.options.retry_reads, bool) + + async def test_validate_suggestion(self): """Validate kwargs in constructor.""" for typo in ["auth", "Auth", "AUTH"]: - expected = f"Unknown option: {typo}. Did you mean one of (authsource, authmechanism, authoidcallowedhosts) or maybe a camelCase version of one? Refer to docstring." + expected = ( + f"Unknown option: {typo}. Did you mean one of (authsource, authmechanism, " + f"authoidcallowedhosts) or maybe a camelCase version of one? Refer to docstring." + ) expected = re.escape(expected) - with self.assertRaisesRegex(ConfigurationError, expected): + with pytest.raises(ConfigurationError, match=expected): AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_logging(self, mock_get_hosts): + async def test_detected_environment_logging(self, caplog): normal_hosts = [ "normal.host.com", "host.cosmos.azure.com", @@ -616,42 +633,47 @@ def test_detected_environment_logging(self, mock_get_hosts): multi_host = ( "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" ) - with self.assertLogs("pymongo", level="INFO") as cm: - for host in normal_hosts: - AsyncMongoClient(host, connect=False) - for host in srv_hosts: - mock_get_hosts.return_value = [(host, 1)] - AsyncMongoClient(host, connect=False) - AsyncMongoClient(multi_host, connect=False) - logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] - self.assertEqual(len(logs), 7) - - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - async def test_detected_environment_warning(self, mock_get_hosts): - with self._caplog.at_level(logging.WARN): - normal_hosts = [ - "host.cosmos.azure.com", - "host.docdb.amazonaws.com", - "host.docdb-elastic.amazonaws.com", - ] - srv_hosts = ["mongodb+srv://:@" + s for s in normal_hosts] - multi_host = ( - "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" - ) - for host in normal_hosts: - with self.assertWarns(UserWarning): - self.simple_client(host) - for host in srv_hosts: - mock_get_hosts.return_value = [(host, 1)] - with self.assertWarns(UserWarning): - self.simple_client(host) - with self.assertWarns(UserWarning): - self.simple_client(multi_host) - - -class TestClient(AsyncIntegrationTest): - def test_multiple_uris(self): - with self.assertRaises(ConfigurationError): + with caplog.at_level(logging.INFO, logger="pymongo"): + with mock.patch("pymongo.srv_resolver._SrvResolver.get_hosts") as mock_get_hosts: + for host in normal_hosts: + AsyncMongoClient(host, connect=False) + for host in srv_hosts: + mock_get_hosts.return_value = [(host, 1)] + AsyncMongoClient(host, connect=False) + AsyncMongoClient(multi_host, connect=False) + logs = [ + record.getMessage() + for record in caplog.records + if record.name == "pymongo.client" + ] + assert len(logs) == 7 + + async def test_detected_environment_warning(self, caplog, simple_client): + normal_hosts = [ + "host.cosmos.azure.com", + "host.docdb.amazonaws.com", + "host.docdb-elastic.amazonaws.com", + ] + srv_hosts = ["mongodb+srv://:@" + s for s in normal_hosts] + multi_host = ( + "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" + ) + with caplog.at_level(logging.WARN, logger="pymongo"): + with mock.patch("pymongo.srv_resolver._SrvResolver.get_hosts") as mock_get_hosts: + with pytest.warns(UserWarning): + for host in normal_hosts: + await simple_client(host) + for host in srv_hosts: + mock_get_hosts.return_value = [(host, 1)] + await simple_client(host) + await simple_client(multi_host) + + +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestClientIntegrationTest(AsyncPyMongoTestCasePyTest): + async def test_multiple_uris(self): + with pytest.raises(ConfigurationError): AsyncMongoClient( host=[ "mongodb+srv://cluster-a.abc12.mongodb.net", @@ -660,22 +682,22 @@ def test_multiple_uris(self): ] ) - async def test_max_idle_time_reaper_default(self): + async def test_max_idle_time_reaper_default(self, async_rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = await self.async_rs_or_single_client() + client = await async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) async with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.conns)) - self.assertTrue(conn in server._pool.conns) + assert 1 == len(server._pool.conns) + assert conn in server._pool.conns - async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): + async def test_max_idle_time_reaper_removes_stale_minPoolSize(self, async_rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = await self.async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = await async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -683,14 +705,16 @@ async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): pass # When the reaper runs at the same time as the get_socket, two # connections could be created and checked into the pool. - self.assertGreaterEqual(len(server._pool.conns), 1) + assert len(server._pool.conns) >= 1 await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): + async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize( + self, async_rs_or_single_client + ): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = await self.async_rs_or_single_client( + client = await async_rs_or_single_client( maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1 ) server = await (await client._get_topology()).select_server( @@ -700,39 +724,39 @@ async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): pass # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two connections from being created. - self.assertEqual(1, len(server._pool.conns)) + assert 1 == len(server._pool.conns) await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - async def test_max_idle_time_reaper_removes_stale(self): + async def test_max_idle_time_reaper_removes_stale(self, async_rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): - # Assert reaper has removed idle socket and NOT replaced it - client = await self.async_rs_or_single_client(maxIdleTimeMS=500) + # Assert that the reaper has removed the idle socket and NOT replaced it. + client = await async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) async with server._pool.checkout() as conn_one: pass - # Assert that the pool does not close connections prematurely. + # Assert that the pool does not close connections prematurely await asyncio.sleep(0.300) async with server._pool.checkout() as conn_two: pass - self.assertIs(conn_one, conn_two) + assert conn_one is conn_two await async_wait_until( lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - async def test_min_pool_size(self): + async def test_min_pool_size(self, async_rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): - client = await self.async_rs_or_single_client() + client = await async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) - self.assertEqual(0, len(server._pool.conns)) + assert len(server._pool.conns) == 0 # Assert that pool started up at minPoolSize - client = await self.async_rs_or_single_client(minPoolSize=10) + client = await async_rs_or_single_client(minPoolSize=10) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -740,145 +764,154 @@ async def test_min_pool_size(self): lambda: len(server._pool.conns) == 10, "pool initialized with 10 connections", ) - - # Assert that if a socket is closed, a new one takes its place + # Assert that if a socket is closed, a new one takes its place. async with server._pool.checkout() as conn: conn.close_conn(None) await async_wait_until( lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", ) - self.assertFalse(conn in server._pool.conns) + assert conn not in server._pool.conns - async def test_max_idle_time_checkout(self): + async def test_max_idle_time_checkout(self, async_rs_or_single_client): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = await self.async_rs_or_single_client(maxIdleTimeMS=500) + client = await async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) async with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.conns)) + assert len(server._pool.conns) == 1 await asyncio.sleep(1) # Sleep so that the socket becomes stale. - async with server._pool.checkout() as new_con: - self.assertNotEqual(conn, new_con) - self.assertEqual(1, len(server._pool.conns)) - self.assertFalse(conn in server._pool.conns) - self.assertTrue(new_con in server._pool.conns) + async with server._pool.checkout() as new_conn: + assert conn != new_conn + assert len(server._pool.conns) == 1 + assert conn not in server._pool.conns + assert new_conn in server._pool.conns # Test that connections are reused if maxIdleTimeMS is not set. - client = await self.async_rs_or_single_client() + client = await async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) async with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.conns)) + assert len(server._pool.conns) == 1 await asyncio.sleep(1) - async with server._pool.checkout() as new_con: - self.assertEqual(conn, new_con) - self.assertEqual(1, len(server._pool.conns)) + async with server._pool.checkout() as new_conn: + assert conn == new_conn + assert len(server._pool.conns) == 1 - async def test_constants(self): + async def test_constants(self, async_client_context_fixture, simple_client): """This test uses AsyncMongoClient explicitly to make sure that host and port are not overloaded. """ - host, port = await async_client_context.host, await async_client_context.port - kwargs: dict = async_client_context.default_client_options.copy() - if async_client_context.auth_enabled: + host, port = ( + await async_client_context_fixture.host, + await async_client_context_fixture.port, + ) + kwargs: dict = async_client_context_fixture.default_client_options.copy() + if async_client_context_fixture.auth_enabled: kwargs["username"] = db_user kwargs["password"] = db_pwd # Set bad defaults. AsyncMongoClient.HOST = "somedomainthatdoesntexist.org" AsyncMongoClient.PORT = 123456789 - with self.assertRaises(AutoReconnect): - c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) + with pytest.raises(AutoReconnect): + c = await simple_client(serverSelectionTimeoutMS=10, **kwargs) await connected(c) - - c = self.simple_client(host, port, **kwargs) + c = await simple_client(host, port, **kwargs) # Override the defaults. No error. await connected(c) - # Set good defaults. AsyncMongoClient.HOST = host AsyncMongoClient.PORT = port - # No error. - c = self.simple_client(**kwargs) + c = await simple_client(**kwargs) await connected(c) - async def test_init_disconnected(self): - host, port = await async_client_context.host, await async_client_context.port - c = await self.async_rs_or_single_client(connect=False) + async def test_init_disconnected( + self, async_client_context_fixture, async_rs_or_single_client, simple_client + ): + host, port = ( + await async_client_context_fixture.host, + await async_client_context_fixture.port, + ) + c = await async_rs_or_single_client(connect=False) # is_primary causes client to block until connected - self.assertIsInstance(await c.is_primary, bool) - c = await self.async_rs_or_single_client(connect=False) - self.assertIsInstance(await c.is_mongos, bool) - c = await self.async_rs_or_single_client(connect=False) - self.assertIsInstance(c.options.pool_options.max_pool_size, int) - self.assertIsInstance(c.nodes, frozenset) - - c = await self.async_rs_or_single_client(connect=False) - self.assertEqual(c.codec_options, CodecOptions()) - c = await self.async_rs_or_single_client(connect=False) - self.assertFalse(await c.primary) - self.assertFalse(await c.secondaries) - c = await self.async_rs_or_single_client(connect=False) - self.assertIsInstance(c.topology_description, TopologyDescription) - self.assertEqual(c.topology_description, c._topology._description) - if async_client_context.is_rs: + assert isinstance(await c.is_primary, bool) + c = await async_rs_or_single_client(connect=False) + assert isinstance(await c.is_mongos, bool) + c = await async_rs_or_single_client(connect=False) + assert isinstance(c.options.pool_options.max_pool_size, int) + assert isinstance(c.nodes, frozenset) + + c = await async_rs_or_single_client(connect=False) + assert c.codec_options == CodecOptions() + c = await async_rs_or_single_client(connect=False) + assert not await c.primary + assert not await c.secondaries + c = await async_rs_or_single_client(connect=False) + assert isinstance(c.topology_description, TopologyDescription) + assert c.topology_description == c._topology._description + if async_client_context_fixture.is_rs: # The primary's host and port are from the replica set config. - self.assertIsNotNone(await c.address) + assert await c.address is not None else: - self.assertEqual(await c.address, (host, port)) - + assert await c.address == (host, port) bad_host = "somedomainthatdoesntexist.org" - c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - with self.assertRaises(ConnectionFailure): + c = await simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + with pytest.raises(ConnectionFailure): await c.pymongo_test.test.find_one() - async def test_init_disconnected_with_auth(self): + async def test_init_disconnected_with_auth(self, simple_client): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - with self.assertRaises(ConnectionFailure): + c = await simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + with pytest.raises(ConnectionFailure): await c.pymongo_test.test.find_one() - async def test_equality(self): - seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await self.async_rs_or_single_client(seed, connect=False) - self.assertEqual(async_client_context.client, c) + async def test_equality( + self, async_client_context_fixture, async_rs_or_single_client, simple_client + ): + seed = "{}:{}".format( + *list(async_client_context_fixture.client._topology_settings.seeds)[0] + ) + c = await async_rs_or_single_client(seed, connect=False) + assert async_client_context_fixture.client == c # Explicitly test inequality - self.assertFalse(async_client_context.client != c) + assert not async_client_context_fixture.client != c - c = await self.async_rs_or_single_client("invalid.com", connect=False) - self.assertNotEqual(async_client_context.client, c) - self.assertTrue(async_client_context.client != c) + c = await async_rs_or_single_client("invalid.com", connect=False) + assert async_client_context_fixture.client != c + assert async_client_context_fixture.client != c - c1 = self.simple_client("a", connect=False) - c2 = self.simple_client("b", connect=False) + c1 = await simple_client("a", connect=False) + c2 = await simple_client("b", connect=False) # Seeds differ: - self.assertNotEqual(c1, c2) + assert c1 != c2 - c1 = self.simple_client(["a", "b", "c"], connect=False) - c2 = self.simple_client(["c", "a", "b"], connect=False) + c1 = await simple_client(["a", "b", "c"], connect=False) + c2 = await simple_client(["c", "a", "b"], connect=False) # Same seeds but out of order still compares equal: - self.assertEqual(c1, c2) - - async def test_hashable(self): - seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await self.async_rs_or_single_client(seed, connect=False) - self.assertIn(c, {async_client_context.client}) - c = await self.async_rs_or_single_client("invalid.com", connect=False) - self.assertNotIn(c, {async_client_context.client}) - - async def test_host_w_port(self): - with self.assertRaises(ValueError): - host = await async_client_context.host + assert c1 == c2 + + async def test_hashable(self, async_client_context_fixture, async_rs_or_single_client): + seed = "{}:{}".format( + *list(async_client_context_fixture.client._topology_settings.seeds)[0] + ) + c = await async_rs_or_single_client(seed, connect=False) + assert c in {async_client_context_fixture.client} + c = await async_rs_or_single_client("invalid.com", connect=False) + assert c not in {async_client_context_fixture.client} + + async def test_host_w_port(self, async_client_context_fixture): + with pytest.raises(ValueError): + host = await async_client_context_fixture.host await connected( AsyncMongoClient( f"{host}:1234567", @@ -887,7 +920,7 @@ async def test_host_w_port(self): ) ) - async def test_repr(self): + async def test_repr(self, simple_client): # Used to test 'eval' below. import bson @@ -897,19 +930,16 @@ async def test_repr(self): connect=False, document_class=SON, ) - the_repr = repr(client) - self.assertIn("AsyncMongoClient(host=", the_repr) - self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) - self.assertIn("connecttimeoutms=12345", the_repr) - self.assertIn("replicaset='replset'", the_repr) - self.assertIn("w=1", the_repr) - self.assertIn("wtimeoutms=100", the_repr) - + assert "AsyncMongoClient(host=" in the_repr + assert "document_class=bson.son.SON, tz_aware=False, connect=False, " in the_repr + assert "connecttimeoutms=12345" in the_repr + assert "replicaset='replset'" in the_repr + assert "w=1" in the_repr + assert "wtimeoutms=100" in the_repr async with eval(the_repr) as client_two: - self.assertEqual(client_two, client) - - client = self.simple_client( + assert client_two == client + client = await simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -919,93 +949,104 @@ async def test_repr(self): connect=False, ) the_repr = repr(client) - self.assertIn("AsyncMongoClient(host=", the_repr) - self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) - self.assertIn("connecttimeoutms=12345", the_repr) - self.assertIn("replicaset='replset'", the_repr) - self.assertIn("sockettimeoutms=None", the_repr) - self.assertIn("w=1", the_repr) - self.assertIn("wtimeoutms=100", the_repr) - + assert "AsyncMongoClient(host=" in the_repr + assert "document_class=dict, tz_aware=False, connect=False, " in the_repr + assert "connecttimeoutms=12345" in the_repr + assert "replicaset='replset'" in the_repr + assert "sockettimeoutms=None" in the_repr + assert "w=1" in the_repr + assert "wtimeoutms=100" in the_repr async with eval(the_repr) as client_two: - self.assertEqual(client_two, client) + assert client_two == client - async def test_getters(self): + async def test_getters(self, async_client_context_fixture): await async_wait_until( - lambda: async_client_context.nodes == self.client.nodes, "find all nodes" + lambda: async_client_context_fixture.nodes == async_client_context_fixture.client.nodes, + "find all nodes", ) - async def test_list_databases(self): - cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] - cursor = await self.client.list_databases() - self.assertIsInstance(cursor, AsyncCommandCursor) + async def test_list_databases(self, async_client_context_fixture, async_rs_or_single_client): + cmd_docs = (await async_client_context_fixture.client.admin.command("listDatabases"))[ + "databases" + ] + cursor = await async_client_context_fixture.client.list_databases() + assert isinstance(cursor, AsyncCommandCursor) helper_docs = await cursor.to_list() - self.assertTrue(len(helper_docs) > 0) - self.assertEqual(len(helper_docs), len(cmd_docs)) + assert len(helper_docs) > 0 + assert len(helper_docs) == len(cmd_docs) # PYTHON-3529 Some fields may change between calls, just compare names. for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): - self.assertIs(type(helper_doc), dict) - self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = await self.async_rs_or_single_client(document_class=SON) - async for doc in await client.list_databases(): - self.assertIs(type(doc), dict) - - await self.client.pymongo_test.test.insert_one({}) - cursor = await self.client.list_databases(filter={"name": "admin"}) + assert isinstance(helper_doc, dict) + assert helper_doc.keys() == cmd_doc.keys() + + client_doc = await async_rs_or_single_client(document_class=SON) + async for doc in await client_doc.list_databases(): + assert isinstance(doc, dict) + + await async_client_context_fixture.client.pymongo_test.test.insert_one({}) + cursor = await async_client_context_fixture.client.list_databases(filter={"name": "admin"}) docs = await cursor.to_list() - self.assertEqual(1, len(docs)) - self.assertEqual(docs[0]["name"], "admin") + assert len(docs) == 1 + assert docs[0]["name"] == "admin" - cursor = await self.client.list_databases(nameOnly=True) + cursor = await async_client_context_fixture.client.list_databases(nameOnly=True) async for doc in cursor: - self.assertEqual(["name"], list(doc)) + assert list(doc) == ["name"] - async def test_list_database_names(self): - await self.client.pymongo_test.test.insert_one({"dummy": "object"}) - await self.client.pymongo_test_mike.test.insert_one({"dummy": "object"}) - cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] + async def test_list_database_names(self, async_client_context_fixture): + await async_client_context_fixture.client.pymongo_test.test.insert_one({"dummy": "object"}) + await async_client_context_fixture.client.pymongo_test_mike.test.insert_one( + {"dummy": "object"} + ) + cmd_docs = (await async_client_context_fixture.client.admin.command("listDatabases"))[ + "databases" + ] cmd_names = [doc["name"] for doc in cmd_docs] - db_names = await self.client.list_database_names() - self.assertTrue("pymongo_test" in db_names) - self.assertTrue("pymongo_test_mike" in db_names) - self.assertEqual(db_names, cmd_names) - - async def test_drop_database(self): - with self.assertRaises(TypeError): - await self.client.drop_database(5) # type: ignore[arg-type] - with self.assertRaises(TypeError): - await self.client.drop_database(None) # type: ignore[arg-type] - - await self.client.pymongo_test.test.insert_one({"dummy": "object"}) - await self.client.pymongo_test2.test.insert_one({"dummy": "object"}) - dbs = await self.client.list_database_names() - self.assertIn("pymongo_test", dbs) - self.assertIn("pymongo_test2", dbs) - await self.client.drop_database("pymongo_test") - - if async_client_context.is_rs: - wc_client = await self.async_rs_or_single_client(w=len(async_client_context.nodes) + 1) - with self.assertRaises(WriteConcernError): + db_names = await async_client_context_fixture.client.list_database_names() + assert "pymongo_test" in db_names + assert "pymongo_test_mike" in db_names + assert db_names == cmd_names + + async def test_drop_database(self, async_client_context_fixture, async_rs_or_single_client): + with pytest.raises(TypeError): + await async_client_context_fixture.client.drop_database(5) # type: ignore[arg-type] + with pytest.raises(TypeError): + await async_client_context_fixture.client.drop_database(None) # type: ignore[arg-type] + + await async_client_context_fixture.client.pymongo_test.test.insert_one({"dummy": "object"}) + await async_client_context_fixture.client.pymongo_test2.test.insert_one({"dummy": "object"}) + dbs = await async_client_context_fixture.client.list_database_names() + assert "pymongo_test" in dbs + assert "pymongo_test2" in dbs + await async_client_context_fixture.client.drop_database("pymongo_test") + + if async_client_context_fixture.is_rs: + wc_client = await async_rs_or_single_client( + w=len(async_client_context_fixture.nodes) + 1 + ) + with pytest.raises(WriteConcernError): await wc_client.drop_database("pymongo_test2") - await self.client.drop_database(self.client.pymongo_test2) - dbs = await self.client.list_database_names() - self.assertNotIn("pymongo_test", dbs) - self.assertNotIn("pymongo_test2", dbs) + await async_client_context_fixture.client.drop_database( + async_client_context_fixture.client.pymongo_test2 + ) + dbs = await async_client_context_fixture.client.list_database_names() + assert "pymongo_test" not in dbs + assert "pymongo_test2" not in dbs - async def test_close(self): - test_client = await self.async_rs_or_single_client() + async def test_close(self, async_rs_or_single_client): + test_client = await async_rs_or_single_client() coll = test_client.pymongo_test.bar await test_client.close() - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): await coll.count_documents({}) - async def test_close_kills_cursors(self): + async def test_close_kills_cursors(self, async_rs_or_single_client): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = await self.async_rs_or_single_client() + test_client = await async_rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() await test_client._process_periodic_tasks() @@ -1017,248 +1058,264 @@ async def test_close_kills_cursors(self): # Open a cursor and leave it open on the server. cursor = coll.find().batch_size(10) - self.assertTrue(bool(await anext(cursor))) - self.assertLess(cursor.retrieved, docs_inserted) + assert bool(await anext(cursor)) + assert cursor.retrieved < docs_inserted # Open a command cursor and leave it open on the server. cursor = await coll.aggregate([], batchSize=10) - self.assertTrue(bool(await anext(cursor))) + assert bool(await anext(cursor)) del cursor # Required for PyPy, Jython and other Python implementations that # don't use reference counting garbage collection. gc.collect() # Close the client and ensure the topology is closed. - self.assertTrue(test_client._topology._opened) + assert test_client._topology._opened await test_client.close() - self.assertFalse(test_client._topology._opened) - test_client = await self.async_rs_or_single_client() + assert not test_client._topology._opened + test_client = await async_rs_or_single_client() # The killCursors task should not need to re-open the topology. await test_client._process_periodic_tasks() - self.assertTrue(test_client._topology._opened) + assert test_client._topology._opened - async def test_close_stops_kill_cursors_thread(self): - client = await self.async_rs_client() + async def test_close_stops_kill_cursors_thread(self, async_rs_client): + client = await async_rs_client() await client.test.test.find_one() - self.assertFalse(client._kill_cursors_executor._stopped) + assert not client._kill_cursors_executor._stopped # Closing the client should stop the thread. await client.close() - self.assertTrue(client._kill_cursors_executor._stopped) + assert client._kill_cursors_executor._stopped # Reusing the closed client should raise an InvalidOperation error. - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): await client.admin.command("ping") # Thread is still stopped. - self.assertTrue(client._kill_cursors_executor._stopped) + assert client._kill_cursors_executor._stopped - async def test_uri_connect_option(self): + async def test_uri_connect_option(self, async_rs_client): # Ensure that topology is not opened if connect=False. - client = await self.async_rs_client(connect=False) - self.assertFalse(client._topology._opened) + client = await async_rs_client(connect=False) + assert not client._topology._opened # Ensure kill cursors thread has not been started. if _IS_SYNC: kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) + assert not (kc_thread and kc_thread.is_alive()) else: kc_task = client._kill_cursors_executor._task - self.assertFalse(kc_task and not kc_task.done()) + assert not (kc_task and not kc_task.done()) # Using the client should open topology and start the thread. await client.admin.command("ping") - self.assertTrue(client._topology._opened) + assert client._topology._opened if _IS_SYNC: kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + assert kc_thread and kc_thread.is_alive() else: kc_task = client._kill_cursors_executor._task - self.assertTrue(kc_task and not kc_task.done()) + assert kc_task and not kc_task.done() - async def test_close_does_not_open_servers(self): - client = await self.async_rs_client(connect=False) + async def test_close_does_not_open_servers(self, async_rs_client): + client = await async_rs_client(connect=False) topology = client._topology - self.assertEqual(topology._servers, {}) + assert topology._servers == {} await client.close() - self.assertEqual(topology._servers, {}) + assert topology._servers == {} - async def test_close_closes_sockets(self): - client = await self.async_rs_client() + async def test_close_closes_sockets(self, async_rs_client): + client = await async_rs_client() await client.test.test.find_one() topology = client._topology await client.close() for server in topology._servers.values(): - self.assertFalse(server._pool.conns) - self.assertTrue(server._monitor._executor._stopped) - self.assertTrue(server._monitor._rtt_monitor._executor._stopped) - self.assertFalse(server._monitor._pool.conns) - self.assertFalse(server._monitor._rtt_monitor._pool.conns) - - def test_bad_uri(self): - with self.assertRaises(InvalidURI): + assert not server._pool.conns + assert server._monitor._executor._stopped + assert server._monitor._rtt_monitor._executor._stopped + assert not server._monitor._pool.conns + assert not server._monitor._rtt_monitor._pool.conns + + async def test_bad_uri(self): + with pytest.raises(InvalidURI): AsyncMongoClient("http://localhost") - @async_client_context.require_auth - @async_client_context.require_no_fips - async def test_auth_from_uri(self): - host, port = await async_client_context.host, await async_client_context.port - await async_client_context.create_user("admin", "admin", "pass") - self.addAsyncCleanup(async_client_context.drop_user, "admin", "admin") - self.addAsyncCleanup(remove_all_users, self.client.pymongo_test) - - await async_client_context.create_user( + @pytest.mark.usefixtures("require_auth") + @pytest.mark.usefixtures("require_no_fips") + @pytest.mark.parametrize("remove_all_users_fixture", ["pymongo_test"], indirect=True) + @pytest.mark.parametrize("drop_user_fixture", [("admin", "admin")], indirect=True) + async def test_auth_from_uri( + self, + async_client_context_fixture, + async_rs_or_single_client_noauth, + remove_all_users_fixture, + drop_user_fixture, + ): + host, port = ( + await async_client_context_fixture.host, + await async_client_context_fixture.port, + ) + await async_client_context_fixture.create_user("admin", "admin", "pass") + + await async_client_context_fixture.create_user( "pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"] ) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): await connected( - await self.async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) + await async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) ) # No error. await connected( - await self.async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) ) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) - with self.assertRaises(OperationFailure): - await connected(await self.async_rs_or_single_client_noauth(uri)) + with pytest.raises(OperationFailure): + await connected(await async_rs_or_single_client_noauth(uri)) # No error. await connected( - await self.async_rs_or_single_client_noauth( + await async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) ) ) # Auth with lazy connection. await ( - await self.async_rs_or_single_client_noauth( + await async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = await self.async_rs_or_single_client_noauth( + bad_client = await async_rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): await bad_client.pymongo_test.test.find_one() - @async_client_context.require_auth - async def test_username_and_password(self): - await async_client_context.create_user("admin", "ad min", "pa/ss") - self.addAsyncCleanup(async_client_context.drop_user, "admin", "ad min") + @pytest.mark.usefixtures("require_auth") + @pytest.mark.parametrize("drop_user_fixture", [("admin", "ad min")], indirect=True) + async def test_username_and_password( + self, async_client_context_fixture, async_rs_or_single_client_noauth, drop_user_fixture + ): + await async_client_context_fixture.create_user("admin", "ad min", "pa/ss") - c = await self.async_rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = await async_rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. - self.assertNotIn("ad min", repr(c)) - self.assertNotIn("ad min", str(c)) - self.assertNotIn("pa/ss", repr(c)) - self.assertNotIn("pa/ss", str(c)) + assert "ad min" not in repr(c) + assert "ad min" not in str(c) + assert "pa/ss" not in repr(c) + assert "pa/ss" not in str(c) # Auth succeeds. await c.server_info() - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): await ( - await self.async_rs_or_single_client_noauth(username="ad min", password="foo") + await async_rs_or_single_client_noauth(username="ad min", password="foo") ).server_info() - @async_client_context.require_auth - @async_client_context.require_no_fips - async def test_lazy_auth_raises_operation_failure(self): - host = await async_client_context.host - lazy_client = await self.async_rs_or_single_client_noauth( + @pytest.mark.usefixtures("require_auth") + @pytest.mark.usefixtures("require_no_fips") + async def test_lazy_auth_raises_operation_failure( + self, async_client_context_fixture, async_rs_or_single_client_noauth + ): + host = await async_client_context_fixture.host + lazy_client = await async_rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) await asyncAssertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one) - @async_client_context.require_no_tls - async def test_unix_socket(self): + @pytest.mark.usefixtures("require_no_tls") + async def test_unix_socket( + self, async_client_context_fixture, async_rs_or_single_client, simple_client + ): if not hasattr(socket, "AF_UNIX"): - raise SkipTest("UNIX-sockets are not supported on this system") + pytest.skip("UNIX-sockets are not supported on this system") - mongodb_socket = "/tmp/mongodb-%d.sock" % (await async_client_context.port,) - encoded_socket = "%2Ftmp%2F" + "mongodb-%d.sock" % (await async_client_context.port,) + mongodb_socket = "/tmp/mongodb-%d.sock" % (await async_client_context_fixture.port,) + encoded_socket = "%2Ftmp%2F" + "mongodb-%d.sock" % ( + await async_client_context_fixture.port, + ) if not os.access(mongodb_socket, os.R_OK): - raise SkipTest("Socket file is not accessible") + pytest.skip("Socket file is not accessible") uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = await self.async_rs_or_single_client(uri) + client = await async_rs_or_single_client(uri) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() - self.assertTrue("pymongo_test" in dbs) + assert "pymongo_test" in dbs - self.assertTrue(mongodb_socket in repr(client)) + assert mongodb_socket in repr(client) # Confirm it fails with a missing socket. - with self.assertRaises(ConnectionFailure): - c = self.simple_client( + with pytest.raises(ConnectionFailure): + c = await simple_client( "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 ) await connected(c) - async def test_document_class(self): - c = self.client + async def test_document_class(self, async_client_context_fixture, async_rs_or_single_client): + c = async_client_context_fixture.client db = c.pymongo_test await db.test.insert_one({"x": 1}) - self.assertEqual(dict, c.codec_options.document_class) - self.assertTrue(isinstance(await db.test.find_one(), dict)) - self.assertFalse(isinstance(await db.test.find_one(), SON)) + assert dict == c.codec_options.document_class + assert isinstance(await db.test.find_one(), dict) + assert not isinstance(await db.test.find_one(), SON) - c = await self.async_rs_or_single_client(document_class=SON) + c = await async_rs_or_single_client(document_class=SON) db = c.pymongo_test - self.assertEqual(SON, c.codec_options.document_class) - self.assertTrue(isinstance(await db.test.find_one(), SON)) + assert SON == c.codec_options.document_class + assert isinstance(await db.test.find_one(), SON) - async def test_timeouts(self): - client = await self.async_rs_or_single_client( + async def test_timeouts(self, async_rs_or_single_client): + client = await async_rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) - self.assertEqual(10.5, (await async_get_pool(client)).opts.connect_timeout) - self.assertEqual(10.5, (await async_get_pool(client)).opts.socket_timeout) - self.assertEqual(10.5, (await async_get_pool(client)).opts.max_idle_time_seconds) - self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds) - self.assertEqual(10.5, client.options.server_selection_timeout) + assert 10.5 == (await async_get_pool(client)).opts.connect_timeout + assert 10.5 == (await async_get_pool(client)).opts.socket_timeout + assert 10.5 == (await async_get_pool(client)).opts.max_idle_time_seconds + assert 10.5 == client.options.pool_options.max_idle_time_seconds + assert 10.5 == client.options.server_selection_timeout - async def test_socket_timeout_ms_validation(self): - c = await self.async_rs_or_single_client(socketTimeoutMS=10 * 1000) - self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) + async def test_socket_timeout_ms_validation(self, async_rs_or_single_client): + c = await async_rs_or_single_client(socketTimeoutMS=10 * 1000) + assert 10 == (await async_get_pool(c)).opts.socket_timeout - c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=None)) - self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) + c = await connected(await async_rs_or_single_client(socketTimeoutMS=None)) + assert (await async_get_pool(c)).opts.socket_timeout is None - c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=0)) - self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) + c = await connected(await async_rs_or_single_client(socketTimeoutMS=0)) + assert (await async_get_pool(c)).opts.socket_timeout is None - with self.assertRaises(ValueError): - async with await self.async_rs_or_single_client(socketTimeoutMS=-1): + with pytest.raises(ValueError): + async with await async_rs_or_single_client(socketTimeoutMS=-1): pass - with self.assertRaises(ValueError): - async with await self.async_rs_or_single_client(socketTimeoutMS=1e10): + with pytest.raises(ValueError): + async with await async_rs_or_single_client(socketTimeoutMS=1e10): pass - with self.assertRaises(ValueError): - async with await self.async_rs_or_single_client(socketTimeoutMS="foo"): + with pytest.raises(ValueError): + async with await async_rs_or_single_client(socketTimeoutMS="foo"): pass - async def test_socket_timeout(self): - no_timeout = self.client + async def test_socket_timeout(self, async_client_context_fixture, async_rs_or_single_client): + no_timeout = async_client_context_fixture.client timeout_sec = 1 - timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addAsyncCleanup(timeout.close) + timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) await no_timeout.pymongo_test.drop_collection("test") await no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1270,24 +1327,22 @@ async def get_x(db): doc = await anext(db.test.find().where(where_func)) return doc["x"] - self.assertEqual(1, await get_x(no_timeout.pymongo_test)) - with self.assertRaises(NetworkTimeout): + assert 1 == await get_x(no_timeout.pymongo_test) + with pytest.raises(NetworkTimeout): await get_x(timeout.pymongo_test) async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) - self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0.1) await client.close() client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) - self.assertAlmostEqual(0, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0) - self.assertRaises( - ValueError, AsyncMongoClient, serverSelectionTimeoutMS="foo", connect=False - ) - self.assertRaises(ValueError, AsyncMongoClient, serverSelectionTimeoutMS=-1, connect=False) - self.assertRaises( + pytest.raises(ValueError, AsyncMongoClient, serverSelectionTimeoutMS="foo", connect=False) + pytest.raises(ValueError, AsyncMongoClient, serverSelectionTimeoutMS=-1, connect=False) + pytest.raises( ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False ) await client.close() @@ -1295,108 +1350,106 @@ async def test_server_selection_timeout(self): client = AsyncMongoClient( "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False ) - self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0.1) await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) - self.assertAlmostEqual(0, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0) await client.close() # Test invalid timeout in URI ignored and set to default. client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) - self.assertAlmostEqual(30, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 30) await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) - self.assertAlmostEqual(30, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 30) - async def test_waitQueueTimeoutMS(self): - client = await self.async_rs_or_single_client(waitQueueTimeoutMS=2000) - self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) + async def test_waitQueueTimeoutMS(self, async_rs_or_single_client): + client = await async_rs_or_single_client(waitQueueTimeoutMS=2000) + assert 2 == (await async_get_pool(client)).opts.wait_queue_timeout - async def test_socketKeepAlive(self): - pool = await async_get_pool(self.client) + async def test_socketKeepAlive(self, async_client_context_fixture): + pool = await async_get_pool(async_client_context_fixture.client) async with pool.checkout() as conn: keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) - self.assertTrue(keepalive) + assert keepalive @no_type_check - async def test_tz_aware(self): - self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") + async def test_tz_aware(self, async_client_context_fixture, async_rs_or_single_client): + pytest.raises(ValueError, AsyncMongoClient, tz_aware="foo") - aware = await self.async_rs_or_single_client(tz_aware=True) - self.addAsyncCleanup(aware.close) - naive = self.client + aware = await async_rs_or_single_client(tz_aware=True) + naive = async_client_context_fixture.client await aware.pymongo_test.drop_collection("test") now = datetime.datetime.now(tz=datetime.timezone.utc) await aware.pymongo_test.test.insert_one({"x": now}) - self.assertEqual(None, (await naive.pymongo_test.test.find_one())["x"].tzinfo) - self.assertEqual(utc, (await aware.pymongo_test.test.find_one())["x"].tzinfo) - self.assertEqual( - (await aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None), - (await naive.pymongo_test.test.find_one())["x"], - ) + assert (await naive.pymongo_test.test.find_one())["x"].tzinfo is None + assert utc == (await aware.pymongo_test.test.find_one())["x"].tzinfo + assert (await aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None) == ( + await naive.pymongo_test.test.find_one() + )["x"] - @async_client_context.require_ipv6 - async def test_ipv6(self): - if async_client_context.tls: + @pytest.mark.usefixtures("require_ipv6") + async def test_ipv6(self, async_client_context_fixture, async_rs_or_single_client_noauth): + if async_client_context_fixture.tls: if not HAVE_IPADDRESS: - raise SkipTest("Need the ipaddress module to test with SSL") + pytest.skip("Need the ipaddress module to test with SSL") - if async_client_context.auth_enabled: + if async_client_context_fixture.auth_enabled: auth_str = f"{db_user}:{db_pwd}@" else: auth_str = "" - uri = "mongodb://%s[::1]:%d" % (auth_str, await async_client_context.port) - if async_client_context.is_rs: - uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") + uri = "mongodb://%s[::1]:%d" % (auth_str, await async_client_context_fixture.port) + if async_client_context_fixture.is_rs: + uri += "/?replicaSet=" + (async_client_context_fixture.replica_set_name or "") - client = await self.async_rs_or_single_client_noauth(uri) + client = await async_rs_or_single_client_noauth(uri) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() - self.assertTrue("pymongo_test" in dbs) - self.assertTrue("pymongo_test_bernie" in dbs) + assert "pymongo_test" in dbs + assert "pymongo_test_bernie" in dbs - async def test_contextlib(self): - client = await self.async_rs_or_single_client() + async def test_contextlib(self, async_rs_or_single_client): + client = await async_rs_or_single_client() await client.pymongo_test.drop_collection("test") await client.pymongo_test.test.insert_one({"foo": "bar"}) # The socket used for the previous commands has been returned to the # pool - self.assertEqual(1, len((await async_get_pool(client)).conns)) + assert 1 == len((await async_get_pool(client)).conns) # contextlib async support was added in Python 3.10 if _IS_SYNC or sys.version_info >= (3, 10): async with contextlib.aclosing(client): - self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) - with self.assertRaises(InvalidOperation): + assert "bar" == (await client.pymongo_test.test.find_one())["foo"] + with pytest.raises(InvalidOperation): await client.pymongo_test.test.find_one() - client = await self.async_rs_or_single_client() + client = await async_rs_or_single_client() async with client as client: - self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) - with self.assertRaises(InvalidOperation): + assert "bar" == (await client.pymongo_test.test.find_one())["foo"] + with pytest.raises(InvalidOperation): await client.pymongo_test.test.find_one() - @async_client_context.require_sync - def test_interrupt_signal(self): + @pytest.mark.usefixtures("require_sync") + def test_interrupt_signal(self, async_client_context_fixture): if sys.platform.startswith("java"): # We can't figure out how to raise an exception on a thread that's # blocked on a socket, whether that's the main thread or a worker, # without simply killing the whole thread in Jython. This suggests # PYTHON-294 can't actually occur in Jython. - raise SkipTest("Can't test interrupts in Jython") + pytest.skip("Can't test interrupts in Jython") if is_greenthread_patched(): - raise SkipTest("Can't reliably test interrupts with green threads") + pytest.skip("Can't reliably test interrupts with green threads") # Test fix for PYTHON-294 -- make sure AsyncMongoClient closes its # socket if it gets an interrupt while waiting to recv() from it. - db = self.client.pymongo_test + db = async_client_context_fixture.client.pymongo_test # A $where clause which takes 1.5 sec to execute where = delay(1.5) @@ -1438,48 +1491,48 @@ def sigalarm(num, frame): except KeyboardInterrupt: raised = True - # Can't use self.assertRaises() because it doesn't catch system - # exceptions - self.assertTrue(raised, "Didn't raise expected KeyboardInterrupt") + assert raised, "Didn't raise expected KeyboardInterrupt" # Raises AssertionError due to PYTHON-294 -- Mongo's response to # the previous find() is still waiting to be read on the socket, # so the request id's don't match. - self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload] + assert {"_id": 1} == next(db.foo.find()) # type: ignore[call-overload] finally: if old_signal_handler: signal.signal(signal.SIGALRM, old_signal_handler) - async def test_operation_failure(self): + async def test_operation_failure(self, async_single_client): # Ensure AsyncMongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = await self.async_single_client() + client = await async_single_client() await client.pymongo_test.test.find_one() pool = await async_get_pool(client) socket_count = len(pool.conns) - self.assertGreaterEqual(socket_count, 1) + assert socket_count >= 1 old_conn = next(iter(pool.conns)) await client.pymongo_test.test.drop() await client.pymongo_test.test.insert_one({"_id": "foo"}) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): await client.pymongo_test.test.insert_one({"_id": "foo"}) - self.assertEqual(socket_count, len(pool.conns)) - new_con = next(iter(pool.conns)) - self.assertEqual(old_conn, new_con) + assert socket_count == len(pool.conns) + new_conn = next(iter(pool.conns)) + assert old_conn == new_conn - async def test_lazy_connect_w0(self): + @pytest.mark.parametrize("drop_database_fixture", ["test_lazy_connect_w0"], indirect=True) + async def test_lazy_connect_w0( + self, async_client_context_fixture, async_rs_or_single_client, drop_database_fixture + ): # Ensure that connect-on-demand works when the first operation is # an unacknowledged write. This exercises _writable_max_wire_version(). # Use a separate collection to avoid races where we're still # completing an operation on a collection while the next test begins. - await async_client_context.client.drop_database("test_lazy_connect_w0") - self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") + await async_client_context_fixture.client.drop_database("test_lazy_connect_w0") - client = await self.async_rs_or_single_client(connect=False, w=0) + client = await async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.insert_one({}) async def predicate(): @@ -1487,7 +1540,7 @@ async def predicate(): await async_wait_until(predicate, "find one document") - client = await self.async_rs_or_single_client(connect=False, w=0) + client = await async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) async def predicate(): @@ -1495,7 +1548,7 @@ async def predicate(): await async_wait_until(predicate, "update one document") - client = await self.async_rs_or_single_client(connect=False, w=0) + client = await async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.delete_one({}) async def predicate(): @@ -1503,11 +1556,11 @@ async def predicate(): await async_wait_until(predicate, "delete one document") - @async_client_context.require_no_mongos - async def test_exhaust_network_error(self): + @pytest.mark.usefixtures("require_no_mongos") + async def test_exhaust_network_error(self, async_rs_or_single_client): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) + client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1519,24 +1572,22 @@ async def test_exhaust_network_error(self): conn = one(pool.conns) conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): await anext(cursor) - self.assertTrue(conn.closed) + assert conn.closed # The semaphore was decremented despite the error. - self.assertEqual(0, pool.requests) + assert 0 == pool.requests - @async_client_context.require_auth - async def test_auth_network_error(self): + @pytest.mark.usefixtures("require_auth") + async def test_auth_network_error(self, async_rs_or_single_client): # Make sure there's no semaphore leak if we get a network error # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. c = await connected( - await self.async_rs_or_single_client( - maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False - ) + await async_rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) ) # Cause a network error on the actual socket. @@ -1546,25 +1597,25 @@ async def test_auth_network_error(self): # AsyncConnection.authenticate logs, but gets a socket.error. Should be # reraised as AutoReconnect. - with self.assertRaises(AutoReconnect): + with pytest.raises(AutoReconnect): await c.test.collection.find_one() # No semaphore leak, the pool is allowed to make a new socket. await c.test.collection.find_one() - @async_client_context.require_no_replica_set - async def test_connect_to_standalone_using_replica_set_name(self): - client = await self.async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - with self.assertRaises(AutoReconnect): + @pytest.mark.usefixtures("require_no_replica_set") + async def test_connect_to_standalone_using_replica_set_name(self, async_single_client): + client = await async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) + with pytest.raises(AutoReconnect): await client.test.test.find_one() - @async_client_context.require_replica_set - async def test_stale_getmore(self): + @pytest.mark.usefixtures("require_replica_set") + async def test_stale_getmore(self, async_rs_client): # A cursor is created, but its member goes down and is removed from # the topology before the getMore message is sent. Test that # AsyncMongoClient._run_operation_with_response handles the error. - with self.assertRaises(AutoReconnect): - client = await self.async_rs_client(connect=False, serverSelectionTimeoutMS=100) + with pytest.raises(AutoReconnect): + client = await async_rs_client(connect=False, serverSelectionTimeoutMS=100) await client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1584,7 +1635,7 @@ async def test_stale_getmore(self): address=("not-a-member", 27017), ) - async def test_heartbeat_frequency_ms(self): + async def test_heartbeat_frequency_ms(self, async_client_context_fixture, async_single_client): class HeartbeatStartedListener(ServerHeartbeatListener): def __init__(self): self.results = [] @@ -1609,116 +1660,129 @@ def init(self, *args): ServerHeartbeatStartedEvent.__init__ = init # type: ignore listener = HeartbeatStartedListener() uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % ( - await async_client_context.host, - await async_client_context.port, + await async_client_context_fixture.host, + await async_client_context_fixture.port, ) - await self.async_single_client(uri, event_listeners=[listener]) + await async_single_client(uri, event_listeners=[listener]) await async_wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) # Default heartbeatFrequencyMS is 10 sec. Check the interval was # closer to 0.5 sec with heartbeatFrequencyMS configured. - self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) + pytest.approx(heartbeat_times[1] - heartbeat_times[0], 0.5, abs=2) finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore - def test_small_heartbeat_frequency_ms(self): + async def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" - with self.assertRaises(ConfigurationError) as context: + with pytest.raises(ConfigurationError) as context: AsyncMongoClient(uri) - self.assertIn("heartbeatFrequencyMS", str(context.exception)) + assert "heartbeatFrequencyMS" in str(context.value) - async def test_compression(self): + async def test_compression( + self, async_client_context_fixture, simple_client, async_single_client + ): def compression_settings(client): pool_options = client.options.pool_options return pool_options._compression_settings - uri = "mongodb://localhost:27017/?compressors=zlib" - client = self.simple_client(uri, connect=False) + client = await simple_client("mongodb://localhost:27017/?compressors=zlib", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + + client = await simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, 4) - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == 4 + + client = await simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 + + client = await simple_client("mongodb://localhost:27017", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017/?compressors=foobar" - client = self.simple_client(uri, connect=False) + assert opts.compressors == [] + assert opts.zlib_compression_level == -1 + + client = await simple_client("mongodb://localhost:27017/?compressors=foobar", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = self.simple_client(uri, connect=False) + assert opts.compressors == [] + assert opts.zlib_compression_level == -1 + + client = await simple_client( + "mongodb://localhost:27017/?compressors=foobar,zlib", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 - # According to the connection string spec, unsupported values - # just raise a warning and are ignored. - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = self.simple_client(uri, connect=False) + client = await simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 + + client = await simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 if not _have_snappy(): - uri = "mongodb://localhost:27017/?compressors=snappy" - client = self.simple_client(uri, connect=False) + client = await simple_client( + "mongodb://localhost:27017/?compressors=snappy", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + assert opts.compressors == [] else: - uri = "mongodb://localhost:27017/?compressors=snappy" - client = self.simple_client(uri, connect=False) + client = await simple_client( + "mongodb://localhost:27017/?compressors=snappy", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy"]) - uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["snappy"] + client = await simple_client( + "mongodb://localhost:27017/?compressors=snappy,zlib", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy", "zlib"]) + assert opts.compressors == ["snappy", "zlib"] if not _have_zstd(): - uri = "mongodb://localhost:27017/?compressors=zstd" - client = self.simple_client(uri, connect=False) + client = await simple_client( + "mongodb://localhost:27017/?compressors=zstd", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + assert opts.compressors == [] else: - uri = "mongodb://localhost:27017/?compressors=zstd" - client = self.simple_client(uri, connect=False) + client = await simple_client( + "mongodb://localhost:27017/?compressors=zstd", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd"]) - uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zstd"] + client = await simple_client( + "mongodb://localhost:27017/?compressors=zstd,zlib", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd", "zlib"]) + assert opts.compressors == ["zstd", "zlib"] - options = async_client_context.default_client_options + options = async_client_context_fixture.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = await self.async_single_client(zlibcompressionlevel=level) - # No error - await client.pymongo_test.test.find_one() + client = await async_single_client(zlibcompressionlevel=level) + await client.pymongo_test.test.find_one() # No error - @async_client_context.require_sync - async def test_reset_during_update_pool(self): - client = await self.async_rs_or_single_client(minPoolSize=10) + @pytest.mark.usefixtures("require_sync") + async def test_reset_during_update_pool(self, async_rs_or_single_client): + client = await async_rs_or_single_client(minPoolSize=10) await client.admin.command("ping") pool = await async_get_pool(client) generation = pool.gen.get_overall() @@ -1746,8 +1810,7 @@ def run(self): t = ResetPoolThread(pool) t.start() - # Ensure that update_pool completes without error even when the pool - # is reset concurrently. + # Ensure that update_pool completes without error even when the pool is reset concurrently. try: while True: for _ in range(10): @@ -1759,15 +1822,14 @@ def run(self): t.join() await client.admin.command("ping") - async def test_background_connections_do_not_hold_locks(self): + async def test_background_connections_do_not_hold_locks(self, async_rs_or_single_client): min_pool_size = 10 - client = await self.async_rs_or_single_client( + client = await async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - # Create a single connection in the pool. - await client.admin.command("ping") + await client.admin.command("ping") # Create a single connection in the pool - # Cause new connections stall for a few seconds. + # Cause new connections to stall for a few seconds. pool = await async_get_pool(client) original_connect = pool.connect @@ -1775,44 +1837,39 @@ async def stall_connect(*args, **kwargs): await asyncio.sleep(2) return await original_connect(*args, **kwargs) - pool.connect = stall_connect - # Un-patch Pool.connect to break the cyclic reference. - self.addCleanup(delattr, pool, "connect") - - # Wait for the background thread to start creating connections - await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections") + try: + pool.connect = stall_connect + + await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections") + # Assert that application operations do not block. + for _ in range(10): + start = time.monotonic() + await client.admin.command("ping") + total = time.monotonic() - start + assert total < 2 + finally: + delattr(pool, "connect") - # Assert that application operations do not block. - for _ in range(10): - start = time.monotonic() - await client.admin.command("ping") - total = time.monotonic() - start - # Each ping command should not take more than 2 seconds - self.assertLess(total, 2) - - @async_client_context.require_replica_set - async def test_direct_connection(self): - # direct_connection=True should result in Single topology. - client = await self.async_rs_or_single_client(directConnection=True) + @pytest.mark.usefixtures("require_replica_set") + async def test_direct_connection(self, async_rs_or_single_client): + client = await async_rs_or_single_client(directConnection=True) await client.admin.command("ping") - self.assertEqual(len(client.nodes), 1) - self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) + assert len(client.nodes) == 1 + assert client._topology_settings.get_topology_type() == TOPOLOGY_TYPE.Single - # direct_connection=False should result in RS topology. - client = await self.async_rs_or_single_client(directConnection=False) + client = await async_rs_or_single_client(directConnection=False) await client.admin.command("ping") - self.assertGreaterEqual(len(client.nodes), 1) - self.assertIn( - client._topology_settings.get_topology_type(), - [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], - ) + assert len(client.nodes) >= 1 + assert client._topology_settings.get_topology_type() in [ + TOPOLOGY_TYPE.ReplicaSetNoPrimary, + TOPOLOGY_TYPE.ReplicaSetWithPrimary, + ] - # directConnection=True, should error with multiple hosts as a list. - with self.assertRaises(ConfigurationError): + with pytest.raises(ConfigurationError): AsyncMongoClient(["host1", "host2"], directConnection=True) - @unittest.skipIf("PyPy" in sys.version, "PYTHON-2927 fails often on PyPy") - async def test_continuous_network_errors(self): + @pytest.mark.skipif("PyPy" in sys.version, reason="PYTHON-2927 fails often on PyPy") + async def test_continuous_network_errors(self, simple_client): def server_description_count(): i = 0 for obj in gc.get_objects(): @@ -1825,50 +1882,45 @@ def server_description_count(): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = self.simple_client( + client = await simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - with self.assertRaises(ServerSelectionTimeoutError): + with pytest.raises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() final_count = server_description_count() - # If a bug like PYTHON-2433 is reintroduced then too many - # ServerDescriptions will be kept alive and this test will fail: - # AssertionError: 19 != 46 within 15 delta (27 difference) - # On Python 3.11 we seem to get more of a delta. - self.assertAlmostEqual(initial_count, final_count, delta=20) - - @async_client_context.require_failCommand_fail_point - async def test_network_error_message(self): - client = await self.async_single_client(retryReads=False) + assert pytest.approx(initial_count, abs=20) == final_count + + @pytest.mark.usefixtures("require_failCommand_fail_point") + async def test_network_error_message(self, async_single_client): + client = await async_single_client(retryReads=False) await client.admin.command("ping") # connect async with self.fail_point( - {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} + client, + {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}, ): assert await client.address is not None expected = "{}:{}: ".format(*(await client.address)) - with self.assertRaisesRegex(AutoReconnect, expected): + with pytest.raises(AutoReconnect, match=expected): await client.pymongo_test.test.find_one({}) - @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") - async def test_process_periodic_tasks(self): - client = await self.async_rs_or_single_client() + @pytest.mark.skipif("PyPy" in sys.version, reason="PYTHON-2938 could fail on PyPy") + async def test_process_periodic_tasks(self, async_rs_or_single_client): + client = await async_rs_or_single_client() coll = client.db.collection await coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) await cursor.next() c_id = cursor.cursor_id - self.assertIsNotNone(c_id) + assert c_id is not None await client.close() - # Add cursor to kill cursors queue del cursor await async_wait_until( - lambda: client._kill_cursors_queue, - "waited for cursor to be added to queue", + lambda: client._kill_cursors_queue, "waited for cursor to be added to queue" ) await client._process_periodic_tasks() # This must not raise or print any exceptions - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): await coll.insert_many([{} for _ in range(5)]) async def test_service_name_from_kwargs(self): @@ -1877,82 +1929,79 @@ async def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) - self.assertEqual(client._topology_settings.srv_service_name, "customname") + assert client._topology_settings.srv_service_name == "customname" + client = AsyncMongoClient( - "mongodb+srv://user:password@test22.test.build.10gen.cc" - "/?srvServiceName=shouldbeoverriden", + "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) - self.assertEqual(client._topology_settings.srv_service_name, "customname") + assert client._topology_settings.srv_service_name == "customname" + client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) - self.assertEqual(client._topology_settings.srv_service_name, "customname") + assert client._topology_settings.srv_service_name == "customname" - async def test_srv_max_hosts_kwarg(self): - client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") - self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) - self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = self.simple_client( + async def test_srv_max_hosts_kwarg(self, simple_client): + client = await simple_client("mongodb+srv://test1.test.build.10gen.cc/") + assert len(client.topology_description.server_descriptions()) > 1 + + client = await simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + assert len(client.topology_description.server_descriptions()) == 1 + + client = await simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) - self.assertEqual(len(client.topology_description.server_descriptions()), 2) + assert len(client.topology_description.server_descriptions()) == 2 - @unittest.skipIf( - async_client_context.load_balancer or async_client_context.serverless, - "loadBalanced clients do not run SDAM", - ) - @unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP") - @async_client_context.require_sync - def test_sigstop_sigcont(self): + @pytest.mark.skipif(sys.platform == "win32", reason="Windows does not support SIGSTOP") + @pytest.mark.usefixtures("require_sdam") + @pytest.mark.usefixtures("require_sync") + def test_sigstop_sigcont(self, async_client_context_fixture): test_dir = os.path.dirname(os.path.realpath(__file__)) script = os.path.join(test_dir, "sigstop_sigcont.py") - p = subprocess.Popen( - [sys.executable, script, async_client_context.uri], + with subprocess.Popen( + [sys.executable, script, async_client_context_fixture.uri], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - ) - self.addCleanup(p.wait, timeout=1) - self.addCleanup(p.kill) - time.sleep(1) - # Stop the child, sleep for twice the streaming timeout - # (heartbeatFrequencyMS + connectTimeoutMS), and restart. - os.kill(p.pid, signal.SIGSTOP) - time.sleep(2) - os.kill(p.pid, signal.SIGCONT) - time.sleep(0.5) - # Tell the script to exit gracefully. - outs, _ = p.communicate(input=b"q\n", timeout=10) - self.assertTrue(outs) - log_output = outs.decode("utf-8") - self.assertIn("TEST STARTED", log_output) - self.assertIn("ServerHeartbeatStartedEvent", log_output) - self.assertIn("ServerHeartbeatSucceededEvent", log_output) - self.assertIn("TEST COMPLETED", log_output) - self.assertNotIn("ServerHeartbeatFailedEvent", log_output) - - async def _test_handshake(self, env_vars, expected_env): + ) as p: + time.sleep(1) + os.kill(p.pid, signal.SIGSTOP) + time.sleep(2) + os.kill(p.pid, signal.SIGCONT) + time.sleep(0.5) + outs, _ = p.communicate(input=b"q\n", timeout=10) + assert outs + log_output = outs.decode("utf-8") + assert "TEST STARTED" in log_output + assert "ServerHeartbeatStartedEvent" in log_output + assert "ServerHeartbeatSucceededEvent" in log_output + assert "TEST COMPLETED" in log_output + assert "ServerHeartbeatFailedEvent" not in log_output + + async def _test_handshake(self, env_vars, expected_env, async_rs_or_single_client): with patch.dict("os.environ", env_vars): metadata = copy.deepcopy(_METADATA) if has_c(): metadata["driver"]["name"] = "PyMongo|c|async" else: metadata["driver"]["name"] = "PyMongo|async" + if expected_env is not None: metadata["env"] = expected_env if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - client = await self.async_rs_or_single_client(serverSelectionTimeoutMS=10000) + + client = await async_rs_or_single_client(serverSelectionTimeoutMS=10000) await client.admin.command("ping") options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata - async def test_handshake_01_aws(self): + async def test_handshake_01_aws(self, async_rs_or_single_client): await self._test_handshake( { "AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", @@ -1960,12 +2009,18 @@ async def test_handshake_01_aws(self): "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", }, {"name": "aws.lambda", "region": "us-east-2", "memory_mb": 1024}, + async_rs_or_single_client, ) - async def test_handshake_02_azure(self): - await self._test_handshake({"FUNCTIONS_WORKER_RUNTIME": "python"}, {"name": "azure.func"}) + async def test_handshake_02_azure(self, async_rs_or_single_client): + await self._test_handshake( + {"FUNCTIONS_WORKER_RUNTIME": "python"}, + {"name": "azure.func"}, + async_rs_or_single_client, + ) - async def test_handshake_03_gcp(self): + async def test_handshake_03_gcp(self, async_rs_or_single_client): + # Regular case with environment variables. await self._test_handshake( { "K_SERVICE": "servicename", @@ -1974,7 +2029,9 @@ async def test_handshake_03_gcp(self): "FUNCTION_REGION": "us-central1", }, {"name": "gcp.func", "region": "us-central1", "memory_mb": 1024, "timeout_sec": 60}, + async_rs_or_single_client, ) + # Extra case for FUNCTION_NAME. await self._test_handshake( { @@ -1984,45 +2041,52 @@ async def test_handshake_03_gcp(self): "FUNCTION_REGION": "us-central1", }, {"name": "gcp.func", "region": "us-central1", "memory_mb": 1024, "timeout_sec": 60}, + async_rs_or_single_client, ) - async def test_handshake_04_vercel(self): + async def test_handshake_04_vercel(self, async_rs_or_single_client): await self._test_handshake( - {"VERCEL": "1", "VERCEL_REGION": "cdg1"}, {"name": "vercel", "region": "cdg1"} + {"VERCEL": "1", "VERCEL_REGION": "cdg1"}, + {"name": "vercel", "region": "cdg1"}, + async_rs_or_single_client, ) - async def test_handshake_05_multiple(self): + async def test_handshake_05_multiple(self, async_rs_or_single_client): await self._test_handshake( {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "FUNCTIONS_WORKER_RUNTIME": "python"}, None, + async_rs_or_single_client, ) - # Extra cases for other combos. + await self._test_handshake( {"FUNCTIONS_WORKER_RUNTIME": "python", "K_SERVICE": "servicename"}, None, + async_rs_or_single_client, ) - await self._test_handshake({"K_SERVICE": "servicename", "VERCEL": "1"}, None) - async def test_handshake_06_region_too_long(self): + await self._test_handshake( + {"K_SERVICE": "servicename", "VERCEL": "1"}, None, async_rs_or_single_client + ) + + async def test_handshake_06_region_too_long(self, async_rs_or_single_client): await self._test_handshake( {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "AWS_REGION": "a" * 512}, {"name": "aws.lambda"}, + async_rs_or_single_client, ) - async def test_handshake_07_memory_invalid_int(self): + async def test_handshake_07_memory_invalid_int(self, async_rs_or_single_client): await self._test_handshake( {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big"}, {"name": "aws.lambda"}, + async_rs_or_single_client, ) - async def test_handshake_08_invalid_aws_ec2(self): + async def test_handshake_08_invalid_aws_ec2(self, async_rs_or_single_client): # AWS_EXECUTION_ENV needs to start with "AWS_Lambda_". - await self._test_handshake( - {"AWS_EXECUTION_ENV": "EC2"}, - None, - ) + await self._test_handshake({"AWS_EXECUTION_ENV": "EC2"}, None, async_rs_or_single_client) - async def test_handshake_09_container_with_provider(self): + async def test_handshake_09_container_with_provider(self, async_rs_or_single_client): await self._test_handshake( { ENV_VAR_K8S: "1", @@ -2036,102 +2100,96 @@ async def test_handshake_09_container_with_provider(self): "region": "us-east-1", "memory_mb": 256, }, + async_rs_or_single_client, ) - def test_dict_hints(self): - self.db.t.find(hint={"x": 1}) + async def test_dict_hints(self, async_client_context_fixture): + async_client_context_fixture.client.db.t.find(hint={"x": 1}) - def test_dict_hints_sort(self): - result = self.db.t.find() + async def test_dict_hints_sort(self, async_client_context_fixture): + result = async_client_context_fixture.client.db.t.find() result.sort({"x": 1}) + async_client_context_fixture.client.db.t.find(sort={"x": 1}) - self.db.t.find(sort={"x": 1}) - - async def test_dict_hints_create_index(self): - await self.db.t.create_index({"x": pymongo.ASCENDING}) + async def test_dict_hints_create_index(self, async_client_context_fixture): + await async_client_context_fixture.client.db.t.create_index({"x": pymongo.ASCENDING}) - async def test_legacy_java_uuid_roundtrip(self): + async def test_legacy_java_uuid_roundtrip(self, async_client_context_fixture): data = BinaryData.java_data docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) - await async_client_context.client.pymongo_test.drop_collection("java_uuid") - db = async_client_context.client.pymongo_test + await async_client_context_fixture.client.pymongo_test.drop_collection("java_uuid") + db = async_client_context_fixture.client.pymongo_test coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) await coll.insert_many(docs) - self.assertEqual(5, await coll.count_documents({})) + assert await coll.count_documents({}) == 5 async for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + assert d["newguid"] == uuid.UUID(d["newguidstring"]) coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) async for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - await async_client_context.client.pymongo_test.drop_collection("java_uuid") + assert d["newguid"] != d["newguidstring"] + await async_client_context_fixture.client.pymongo_test.drop_collection("java_uuid") - async def test_legacy_csharp_uuid_roundtrip(self): + async def test_legacy_csharp_uuid_roundtrip(self, async_client_context_fixture): data = BinaryData.csharp_data docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) - await async_client_context.client.pymongo_test.drop_collection("csharp_uuid") - db = async_client_context.client.pymongo_test + await async_client_context_fixture.client.pymongo_test.drop_collection("csharp_uuid") + db = async_client_context_fixture.client.pymongo_test coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) await coll.insert_many(docs) - self.assertEqual(5, await coll.count_documents({})) + assert await coll.count_documents({}) == 5 async for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + assert d["newguid"] == uuid.UUID(d["newguidstring"]) coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) async for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - await async_client_context.client.pymongo_test.drop_collection("csharp_uuid") + assert d["newguid"] != d["newguidstring"] + await async_client_context_fixture.client.pymongo_test.drop_collection("csharp_uuid") - async def test_uri_to_uuid(self): + async def test_uri_to_uuid(self, async_single_client): uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" - client = await self.async_single_client(uri, connect=False) - self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) + client = await async_single_client(uri, connect=False) + assert client.pymongo_test.test.codec_options.uuid_representation == CSHARP_LEGACY - async def test_uuid_queries(self): - db = async_client_context.client.pymongo_test + async def test_uuid_queries(self, async_client_context_fixture): + db = async_client_context_fixture.client.pymongo_test coll = db.test await coll.drop() uu = uuid.uuid4() await coll.insert_one({"uuid": Binary(uu.bytes, 3)}) - self.assertEqual(1, await coll.count_documents({})) + assert await coll.count_documents({}) == 1 - # Test regular UUID queries (using subtype 4). coll = db.get_collection( "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) ) - self.assertEqual(0, await coll.count_documents({"uuid": uu})) + assert await coll.count_documents({"uuid": uu}) == 0 await coll.insert_one({"uuid": uu}) - self.assertEqual(2, await coll.count_documents({})) - docs = await coll.find({"uuid": uu}).to_list() - self.assertEqual(1, len(docs)) - self.assertEqual(uu, docs[0]["uuid"]) + assert await coll.count_documents({}) == 2 + docs = await coll.find({"uuid": uu}).to_list(length=1) + assert len(docs) == 1 + assert docs[0]["uuid"] == uu - # Test both. uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) predicate = {"uuid": {"$in": [uu, uu_legacy]}} - self.assertEqual(2, await coll.count_documents(predicate)) - docs = await coll.find(predicate).to_list() - self.assertEqual(2, len(docs)) + assert await coll.count_documents(predicate) == 2 + docs = await coll.find(predicate).to_list(length=2) + assert len(docs) == 2 await coll.drop() -class TestExhaustCursor(AsyncIntegrationTest): - """Test that clients properly handle errors from exhaust cursors.""" - - def setUp(self): - super().setUp() - if async_client_context.is_mongos: - raise SkipTest("mongos doesn't support exhaust, SERVER-2627") - - async def test_exhaust_query_server_error(self): +@pytest.mark.usefixtures("require_no_mongos") +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestExhaustCursor(AsyncPyMongoTestCasePyTest): + async def test_exhaust_query_server_error(self, async_rs_or_single_client): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await self.async_rs_or_single_client(maxPoolSize=1)) + client = await connected(await async_rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -2143,23 +2201,22 @@ async def test_exhaust_query_server_error(self): SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST ) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): await cursor.next() - self.assertFalse(conn.closed) + assert not conn.closed # The socket was checked in and the semaphore was decremented. - self.assertIn(conn, pool.conns) - self.assertEqual(0, pool.requests) + assert conn in pool.conns + assert pool.requests == 0 - async def test_exhaust_getmore_server_error(self): + async def test_exhaust_getmore_server_error(self, async_rs_or_single_client): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await self.async_rs_or_single_client(maxPoolSize=1) + client = await async_rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test await collection.drop() await collection.insert_many([{} for _ in range(200)]) - self.addAsyncCleanup(async_client_context.client.pymongo_test.test.drop) pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2181,21 +2238,19 @@ async def receive_message(request_id): return message._OpReply.unpack(msg) conn.receive_message = receive_message - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): await cursor.to_list() # Unpatch the instance. del conn.receive_message # The socket is returned to the pool and it still works. - self.assertEqual(200, await collection.count_documents({})) - self.assertIn(conn, pool.conns) + assert 200 == await collection.count_documents({}) + assert conn in pool.conns - async def test_exhaust_query_network_error(self): + async def test_exhaust_query_network_error(self, async_rs_or_single_client): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected( - await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) - ) + client = await connected(await async_rs_or_single_client(maxPoolSize=1, retryReads=False)) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2205,18 +2260,18 @@ async def test_exhaust_query_network_error(self): conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): await cursor.next() - self.assertTrue(conn.closed) + assert conn.closed # The socket was closed and the semaphore was decremented. - self.assertNotIn(conn, pool.conns) - self.assertEqual(0, pool.requests) + assert conn not in pool.conns + assert 0 == pool.requests - async def test_exhaust_getmore_network_error(self): + async def test_exhaust_getmore_network_error(self, async_rs_or_single_client): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await self.async_rs_or_single_client(maxPoolSize=1) + client = await async_rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test await collection.drop() await collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2233,39 +2288,39 @@ async def test_exhaust_getmore_network_error(self): conn.conn.close() # A getmore fails. - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): await cursor.to_list() - self.assertTrue(conn.closed) + assert conn.closed await async_wait_until( lambda: len(client._kill_cursors_queue) == 0, "waited for all killCursor requests to complete", ) # The socket was closed and the semaphore was decremented. - self.assertNotIn(conn, pool.conns) - self.assertEqual(0, pool.requests) + assert conn not in pool.conns + assert 0 == pool.requests - @async_client_context.require_sync - def test_gevent_task(self): + @pytest.mark.usefixtures("require_sync") + def test_gevent_task(self, async_client_context_fixture): if not gevent_monkey_patched(): - raise SkipTest("Must be running monkey patched by gevent") + pytest.skip("Must be running monkey patched by gevent") from gevent import spawn def poller(): while True: - async_client_context.client.pymongo_test.test.insert_one({}) + async_client_context_fixture.client.pymongo_test.test.insert_one({}) task = spawn(poller) task.kill() - self.assertTrue(task.dead) + assert task.dead - @async_client_context.require_sync - def test_gevent_timeout(self): + @pytest.mark.usefixtures("require_sync") + def test_gevent_timeout(self, async_rs_or_single_client): if not gevent_monkey_patched(): - raise SkipTest("Must be running monkey patched by gevent") + pytest.skip("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = self.async_rs_or_single_client(maxPoolSize=1) + client = async_rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2286,19 +2341,19 @@ def timeout_task(): tt = spawn(timeout_task) tt.join(15) ct.join(15) - self.assertTrue(tt.dead) - self.assertTrue(ct.dead) - self.assertIsNone(tt.get()) - self.assertIsNone(ct.get()) + assert tt.dead + assert ct.dead + assert tt.get() is None + assert ct.get() is None - @async_client_context.require_sync - def test_gevent_timeout_when_creating_connection(self): + @pytest.mark.usefixtures("require_sync") + def test_gevent_timeout_when_creating_connection(self, async_rs_or_single_client): if not gevent_monkey_patched(): - raise SkipTest("Must be running monkey patched by gevent") + pytest.skip("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = self.async_rs_or_single_client() - self.addCleanup(client.close) + client = async_rs_or_single_client() + coll = client.pymongo_test.test pool = async_get_pool(client) @@ -2321,23 +2376,35 @@ def timeout_task(): tt.join(10) # Assert that we got our active_sockets count back - self.assertEqual(pool.active_sockets, 0) + assert pool.active_sockets == 0 # Assert the greenlet is dead - self.assertTrue(tt.dead) + assert tt.dead # Assert that the Timeout was raised all the way to the try - self.assertTrue(tt.get()) + assert tt.get() # Unpatch the instance. del pool.connect -class TestClientLazyConnect(AsyncIntegrationTest): +@pytest.mark.usefixtures("require_sync") +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestClientLazyConnect: """Test concurrent operations on a lazily-connecting MongoClient.""" - def _get_client(self): - return self.async_rs_or_single_client(connect=False) + @pytest.fixture + def _get_client(self, async_rs_or_single_client): + clients = [] + + def _make_client(): + client = async_rs_or_single_client(connect=False) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() - @async_client_context.require_sync - def test_insert_one(self): + def test_insert_one(self, _get_client, async_client_context_fixture): def reset(collection): collection.drop() @@ -2345,12 +2412,11 @@ def insert_one(collection, _): collection.insert_one({}) def test(collection): - self.assertEqual(NTHREADS, collection.count_documents({})) + assert NTHREADS == collection.count_documents({}) - lazy_client_trial(reset, insert_one, test, self._get_client) + lazy_client_trial(reset, insert_one, test, _get_client, async_client_context_fixture) - @async_client_context.require_sync - def test_update_one(self): + def test_update_one(self, _get_client, async_client_context_fixture): def reset(collection): collection.drop() collection.insert_one({"i": 0}) @@ -2360,12 +2426,11 @@ def update_one(collection, _): collection.update_one({}, {"$inc": {"i": 1}}) def test(collection): - self.assertEqual(NTHREADS, collection.find_one()["i"]) + assert NTHREADS == collection.find_one()["i"] - lazy_client_trial(reset, update_one, test, self._get_client) + lazy_client_trial(reset, update_one, test, _get_client, async_client_context_fixture) - @async_client_context.require_sync - def test_delete_one(self): + def test_delete_one(self, _get_client, async_client_context_fixture): def reset(collection): collection.drop() collection.insert_many([{"i": i} for i in range(NTHREADS)]) @@ -2374,12 +2439,11 @@ def delete_one(collection, i): collection.delete_one({"i": i}) def test(collection): - self.assertEqual(0, collection.count_documents({})) + assert 0 == collection.count_documents({}) - lazy_client_trial(reset, delete_one, test, self._get_client) + lazy_client_trial(reset, delete_one, test, _get_client, async_client_context_fixture) - @async_client_context.require_sync - def test_find_one(self): + def test_find_one(self, _get_client, async_client_context_fixture): results: list = [] def reset(collection): @@ -2391,14 +2455,23 @@ def find_one(collection, _): results.append(collection.find_one()) def test(collection): - self.assertEqual(NTHREADS, len(results)) + assert NTHREADS == len(results) - lazy_client_trial(reset, find_one, test, self._get_client) + lazy_client_trial(reset, find_one, test, _get_client, async_client_context_fixture) -class TestMongoClientFailover(AsyncMockClientTest): - async def test_discover_primary(self): - c = await AsyncMockClient.get_async_mock_client( +@pytest.mark.usefixtures("require_no_load_balancer") +@pytest.mark.unit +class TestMongoClientFailover: + @pytest.fixture(scope="class", autouse=True) + def _client_knobs(self): + knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) + knobs.enable() + yield knobs + knobs.disable() + + async def test_discover_primary(self, async_mock_client): + c = await async_mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2406,11 +2479,10 @@ async def test_discover_primary(self): replicaSet="rs", heartbeatFrequencyMS=500, ) - self.addAsyncCleanup(c.close) await async_wait_until(lambda: len(c.nodes) == 3, "connect") - self.assertEqual(await c.address, ("a", 1)) + assert await c.address == ("a", 1) # Fail over. c.kill_host("a:1") c.mock_primary = "b:2" @@ -2420,11 +2492,11 @@ async def predicate(): await async_wait_until(predicate, "wait for server address to be updated") # a:1 not longer in nodes. - self.assertLess(len(c.nodes), 3) + assert len(c.nodes) < 3 - async def test_reconnect(self): + async def test_reconnect(self, async_mock_client): # Verify the node list isn't forgotten during a network failure. - c = await AsyncMockClient.get_async_mock_client( + c = await async_mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2433,7 +2505,6 @@ async def test_reconnect(self): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.close) await async_wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2442,10 +2513,10 @@ async def test_reconnect(self): c.kill_host("b:2") c.kill_host("c:3") - # AsyncMongoClient discovers it's alone. The first attempt raises either + # MongoClient discovers it's alone. The first attempt raises either # ServerSelectionTimeoutError or AutoReconnect (from # AsyncMockPool.get_socket). - with self.assertRaises(AutoReconnect): + with pytest.raises(AutoReconnect): await c.db.collection.find_one() # But it can reconnect. @@ -2453,14 +2524,14 @@ async def test_reconnect(self): await (await c._get_topology()).select_servers( writable_server_selector, _Op.TEST, server_selection_timeout=10 ) - self.assertEqual(await c.address, ("a", 1)) + assert await c.address == ("a", 1) - async def _test_network_error(self, operation_callback): + async def _test_network_error(self, async_mock_client, operation_callback): # Verify only the disconnected server is reset by a network failure. # Disable background refresh. with client_knobs(heartbeat_frequency=999999): - c = AsyncMockClient( + c = await async_mock_client( standalones=[], members=["a:1", "b:2"], mongoses=[], @@ -2471,8 +2542,6 @@ async def _test_network_error(self, operation_callback): serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.close) - # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION) c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1) @@ -2481,62 +2550,63 @@ async def _test_network_error(self, operation_callback): c.kill_host("a:1") - # AsyncMongoClient is disconnected from the primary. This raises either + # MongoClient is disconnected from the primary. This raises either # ServerSelectionTimeoutError or AutoReconnect (from # MockPool.get_socket). - with self.assertRaises(AutoReconnect): + with pytest.raises(AutoReconnect): await operation_callback(c) # The primary's description is reset. server_a = (await c._get_topology()).get_server_by_address(("a", 1)) sd_a = server_a.description - self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type) - self.assertEqual(0, sd_a.min_wire_version) - self.assertEqual(0, sd_a.max_wire_version) + assert SERVER_TYPE.Unknown == sd_a.server_type + assert 0 == sd_a.min_wire_version + assert 0 == sd_a.max_wire_version # ...but not the secondary's. server_b = (await c._get_topology()).get_server_by_address(("b", 2)) sd_b = server_b.description - self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type) - self.assertEqual(2, sd_b.min_wire_version) - self.assertEqual(MIN_SUPPORTED_WIRE_VERSION + 1, sd_b.max_wire_version) + assert sd_b.server_type == SERVER_TYPE.RSSecondary + assert sd_b.min_wire_version == 2 + assert sd_b.max_wire_version == MIN_SUPPORTED_WIRE_VERSION + 1 - async def test_network_error_on_query(self): + async def test_network_error_on_query(self, async_mock_client): async def callback(client): return await client.db.collection.find_one() - await self._test_network_error(callback) + await self._test_network_error(async_mock_client, callback) - async def test_network_error_on_insert(self): + async def test_network_error_on_insert(self, async_mock_client): async def callback(client): return await client.db.collection.insert_one({}) - await self._test_network_error(callback) + await self._test_network_error(async_mock_client, callback) - async def test_network_error_on_update(self): + async def test_network_error_on_update(self, async_mock_client): async def callback(client): return await client.db.collection.update_one({}, {"$unset": "x"}) - await self._test_network_error(callback) + await self._test_network_error(async_mock_client, callback) - async def test_network_error_on_replace(self): + async def test_network_error_on_replace(self, async_mock_client): async def callback(client): return await client.db.collection.replace_one({}, {}) - await self._test_network_error(callback) + await self._test_network_error(async_mock_client, callback) - async def test_network_error_on_delete(self): + async def test_network_error_on_delete(self, async_mock_client): async def callback(client): return await client.db.collection.delete_many({}) - await self._test_network_error(callback) + await self._test_network_error(async_mock_client, callback) -class TestClientPool(AsyncMockClientTest): - @async_client_context.require_connection - async def test_rs_client_does_not_maintain_pool_to_arbiters(self): +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestClientPool: + async def test_rs_client_does_not_maintain_pool_to_arbiters(self, async_mock_client): listener = CMAPListener() - c = await AsyncMockClient.get_async_mock_client( + c = await async_mock_client( standalones=[], members=["a:1", "b:2", "c:3", "d:4"], mongoses=[], @@ -2547,27 +2617,21 @@ async def test_rs_client_does_not_maintain_pool_to_arbiters(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.close) await async_wait_until(lambda: len(c.nodes) == 3, "connect") - self.assertEqual(await c.address, ("a", 1)) - self.assertEqual(await c.arbiters, {("c", 3)}) - # Assert that we create 2 and only 2 pooled connections. + assert await c.address == ("a", 1) + assert await c.arbiters == {("c", 3)} await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2) - self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) - # Assert that we do not create connections to arbiters. + assert listener.event_count(monitoring.ConnectionCreatedEvent) == 2 arbiter = c._topology.get_server_by_address(("c", 3)) - self.assertFalse(arbiter.pool.conns) - # Assert that we do not create connections to unknown servers. + assert not arbiter.pool.conns arbiter = c._topology.get_server_by_address(("d", 4)) - self.assertFalse(arbiter.pool.conns) - # Arbiter pool is not marked ready. - self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 2) + assert not arbiter.pool.conns + assert listener.event_count(monitoring.PoolReadyEvent) == 2 - @async_client_context.require_connection - async def test_direct_client_maintains_pool_to_arbiter(self): + async def test_direct_client_maintains_pool_to_arbiter(self, async_mock_client): listener = CMAPListener() - c = await AsyncMockClient.get_async_mock_client( + c = await async_mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2577,18 +2641,11 @@ async def test_direct_client_maintains_pool_to_arbiter(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.close) await async_wait_until(lambda: len(c.nodes) == 1, "connect") - self.assertEqual(await c.address, ("c", 3)) - # Assert that we create 1 pooled connection. + assert await c.address == ("c", 3) await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1) - self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) + assert listener.event_count(monitoring.ConnectionCreatedEvent) == 1 arbiter = c._topology.get_server_by_address(("c", 3)) - self.assertEqual(len(arbiter.pool.conns), 1) - # Arbiter pool is marked ready. - self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1) - - -if __name__ == "__main__": - unittest.main() + assert len(arbiter.pool.conns) == 1 + assert listener.event_count(monitoring.PoolReadyEvent) == 1 diff --git a/test/conftest.py b/test/conftest.py index 91fad28d0a..a55116e788 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,10 +2,26 @@ import asyncio import sys -from test import pytest_conf, setup, teardown +from test import ( + MONGODB_API_VERSION, + ClientContext, + _connection_string, + db_pwd, + db_user, + pytest_conf, + setup, + teardown, +) +from test.pymongo_mocks import MockClient +from test.utils import FunctionCallRecorder +from typing import Any, Callable import pytest +import pymongo +from pymongo import MongoClient +from pymongo.uri_parser import parse_uri + _IS_SYNC = True @@ -20,11 +36,335 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest.fixture(scope="package", autouse=True) +@pytest.fixture(scope="session") +def client_context_fixture(): + client = ClientContext() + client.init() + yield client + if client.client is not None: + if not client.is_data_lake: + client.client.drop_database("pymongo-pooling-tests") + client.client.drop_database("pymongo_test") + client.client.drop_database("pymongo_test1") + client.client.drop_database("pymongo_test2") + client.client.drop_database("pymongo_test_mike") + client.client.drop_database("pymongo_test_bernie") + client.client.close() + + +@pytest.fixture +def require_integration(client_context_fixture): + if not client_context_fixture.connected: + pytest.fail("Integration tests require a MongoDB server") + + +@pytest.fixture(scope="session") +def test_environment(client_context_fixture): + requirements = {} + requirements["SUPPORT_TRANSACTIONS"] = client_context_fixture.supports_transactions() + requirements["IS_DATA_LAKE"] = client_context_fixture.is_data_lake + requirements["IS_SYNC"] = _IS_SYNC + requirements["IS_SYNC"] = _IS_SYNC + requirements["REQUIRE_API_VERSION"] = MONGODB_API_VERSION + requirements[ + "SUPPORTS_FAILCOMMAND_FAIL_POINT" + ] = client_context_fixture.supports_failCommand_fail_point + requirements["IS_NOT_MMAP"] = client_context_fixture.is_not_mmap + requirements["SERVER_VERSION"] = client_context_fixture.version + requirements["AUTH_ENABLED"] = client_context_fixture.auth_enabled + requirements["FIPS_ENABLED"] = client_context_fixture.fips_enabled + requirements["IS_RS"] = client_context_fixture.is_rs + requirements["MONGOSES"] = len(client_context_fixture.mongoses) + requirements["SECONDARIES_COUNT"] = client_context_fixture.secondaries_count + requirements["SECONDARY_READ_PREF"] = client_context_fixture.supports_secondary_read_pref + requirements["HAS_IPV6"] = client_context_fixture.has_ipv6 + requirements["IS_SERVERLESS"] = client_context_fixture.serverless + requirements["IS_LOAD_BALANCER"] = client_context_fixture.load_balancer + requirements["TEST_COMMANDS_ENABLED"] = client_context_fixture.test_commands_enabled + requirements["IS_TLS"] = client_context_fixture.tls + requirements["IS_TLS_CERT"] = client_context_fixture.tlsCertificateKeyFile + requirements["SERVER_IS_RESOLVEABLE"] = client_context_fixture.server_is_resolvable + requirements["SESSIONS_ENABLED"] = client_context_fixture.sessions_enabled + requirements["SUPPORTS_RETRYABLE_WRITES"] = client_context_fixture.supports_retryable_writes() + yield requirements + + +@pytest.fixture +def require_auth(test_environment): + if not test_environment["AUTH_ENABLED"]: + pytest.skip("Authentication is not enabled on the server") + + +@pytest.fixture +def require_no_fips(test_environment): + if test_environment["FIPS_ENABLED"]: + pytest.skip("Test cannot run on a FIPS-enabled host") + + +@pytest.fixture +def require_no_tls(test_environment): + if test_environment["IS_TLS"]: + pytest.skip("Must be able to connect without TLS") + + +@pytest.fixture +def require_ipv6(test_environment): + if not test_environment["HAS_IPV6"]: + pytest.skip("No IPv6") + + +@pytest.fixture +def require_sync(test_environment): + if not _IS_SYNC: + pytest.skip("This test only works with the synchronous API") + + +@pytest.fixture +def require_no_mongos(test_environment): + if test_environment["MONGOSES"]: + pytest.skip("Must be connected to a mongod, not a mongos") + + +@pytest.fixture +def require_no_replica_set(test_environment): + if test_environment["IS_RS"]: + pytest.skip("Connected to a replica set, not a standalone mongod") + + +@pytest.fixture +def require_replica_set(test_environment): + if not test_environment["IS_RS"]: + pytest.skip("Not connected to a replica set") + + +@pytest.fixture +def require_sdam(test_environment): + if test_environment["IS_SERVERLESS"] or test_environment["IS_LOAD_BALANCER"]: + pytest.skip("loadBalanced and serverless clients do not run SDAM") + + +@pytest.fixture +def require_no_load_balancer(test_environment): + if test_environment["IS_LOAD_BALANCER"]: + pytest.skip("Must not be connected to a load balancer") + + +@pytest.fixture +def require_failCommand_fail_point(test_environment): + if not test_environment["SUPPORTS_FAILCOMMAND_FAIL_POINT"]: + pytest.skip("failCommand fail point must be supported") + + +@pytest.fixture(scope="session", autouse=True) def test_setup_and_teardown(): setup() yield teardown() +def _async_mongo_client( + client_context_fixture, host, port, authenticate=True, directConnection=None, **kwargs +): + """Create a new client over SSL/TLS if necessary.""" + host = host or client_context_fixture.host + port = port or client_context_fixture.port + client_options: dict = client_context_fixture.default_client_options.copy() + if client_context_fixture.replica_set_name and not directConnection: + client_options["replicaSet"] = client_context_fixture.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if client_context_fixture.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = MongoClient(uri, port, **client_options) + if client._options.connect: + client._connect() + return client + + +@pytest.fixture() +def single_client_noauth(client_context_fixture) -> Callable[..., MongoClient]: + """Make a direct connection. Don't authenticate.""" + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = _async_mongo_client( + client_context_fixture, h, p, authenticate=False, directConnection=True, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def single_client(client_context_fixture) -> Callable[..., MongoClient]: + """Make a direct connection, and authenticate if necessary.""" + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = _async_mongo_client(client_context_fixture, h, p, directConnection=True, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def rs_client_noauth(client_context_fixture) -> Callable[..., MongoClient]: + """Connect to the replica set. Don't authenticate.""" + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = _async_mongo_client(client_context_fixture, h, p, authenticate=False, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def rs_client(client_context_fixture) -> Callable[..., MongoClient]: + """Connect to the replica set and authenticate if necessary.""" + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = _async_mongo_client(client_context_fixture, h, p, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def rs_or_single_client_noauth(client_context_fixture) -> Callable[..., MongoClient]: + """Connect to the replica set if there is one, otherwise the standalone. + + Like rs_or_single_client, but does not authenticate. + """ + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = _async_mongo_client(client_context_fixture, h, p, authenticate=False, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def rs_or_single_client(client_context_fixture) -> Callable[..., MongoClient]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + client = _async_mongo_client(client_context_fixture, h, p, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def simple_client() -> Callable[..., MongoClient]: + clients = [] + + def _make_client(h: Any = None, p: Any = None, **kwargs: Any): + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture(scope="function") +def patch_resolver(): + from pymongo.srv_resolver import _resolve + + patched_resolver = FunctionCallRecorder(_resolve) + pymongo.srv_resolver._resolve = patched_resolver + yield patched_resolver + pymongo.srv_resolver._resolve = _resolve + + +@pytest.fixture() +def mock_client(): + clients = [] + + def _make_client( + standalones, + members, + mongoses, + hello_hosts=None, + arbiters=None, + down_hosts=None, + *args, + **kwargs, + ): + client = MockClient.get_mock_client( + standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs + ) + clients.append(client) + return client + + yield _make_client + for client in clients: + client.close() + + +@pytest.fixture() +def remove_all_users_fixture(client_context_fixture, request): + db_name = request.param + yield + client_context_fixture.client[db_name].command( + "dropAllUsersFromDatabase", 1, writeConcern={"w": client_context_fixture.w} + ) + + +@pytest.fixture() +def drop_user_fixture(client_context_fixture, request): + db, user = request.param + yield + client_context_fixture.drop_user(db, user) + + +@pytest.fixture() +def drop_database_fixture(client_context_fixture, request): + db = request.param + yield + client_context_fixture.client.drop_database(db) + + pytest_collection_modifyitems = pytest_conf.pytest_collection_modifyitems diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 7662dc9682..d352aae070 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -165,7 +165,8 @@ def get_mock_client( standalones, members, mongoses, hello_hosts, arbiters, down_hosts, *args, **kwargs ) - c._connect() + if kwargs.get("connect", True): + c._connect() return c def kill_host(self, host): diff --git a/test/pytest_conf.py b/test/pytest_conf.py index a6e24cd9b1..1a198956a5 100644 --- a/test/pytest_conf.py +++ b/test/pytest_conf.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + def pytest_collection_modifyitems(items, config): # Markers that should overlap with the default markers. @@ -10,6 +12,6 @@ def pytest_collection_modifyitems(items, config): default_marker = "default_async" else: default_marker = "default" - markers = [m for m in item.iter_markers() if m not in overlap_markers] + markers = [m for m in item.iter_markers() if m.name not in overlap_markers] if not markers: item.add_marker(default_marker) diff --git a/test/test_client.py b/test/test_client.py index 2a33077f5f..35f2fe67b4 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -17,7 +17,6 @@ import _thread as thread import asyncio -import base64 import contextlib import copy import datetime @@ -46,24 +45,18 @@ from test import ( HAVE_IPADDRESS, - IntegrationTest, - MockClientTest, + PyMongoTestCasePyTest, SkipTest, - UnitTest, - client_context, client_knobs, connected, db_pwd, db_user, - remove_all_users, - unittest, ) -from test.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( NTHREADS, CMAPListener, - FunctionCallRecorder, + _default_pytest_mark, assertRaisesExactly, delay, get_pool, @@ -124,20 +117,19 @@ _IS_SYNC = True -class ClientUnitTest(UnitTest): - """MongoClient tests that don't require a server.""" +pytestmark = _default_pytest_mark(_IS_SYNC) - client: MongoClient - def setUp(self) -> None: - self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) - - @pytest.fixture(autouse=True) - def inject_fixtures(self, caplog): - self._caplog = caplog +@pytest.mark.unit +class TestClientUnitTest: + @pytest.fixture() + def client(self, rs_or_single_client) -> MongoClient: + client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + yield client + client.close() - def test_keyword_arg_defaults(self): - client = self.simple_client( + def test_keyword_arg_defaults(self, simple_client): + client = simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -153,216 +145,249 @@ def test_keyword_arg_defaults(self): options = client.options pool_opts = options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) + assert pool_opts.socket_timeout is None # socket.Socket.settimeout takes a float in seconds - self.assertEqual(20.0, pool_opts.connect_timeout) - self.assertEqual(None, pool_opts.wait_queue_timeout) - self.assertEqual(None, pool_opts._ssl_context) - self.assertEqual(None, options.replica_set_name) - self.assertEqual(ReadPreference.PRIMARY, client.read_preference) - self.assertAlmostEqual(12, client.options.server_selection_timeout) - - def test_connect_timeout(self): - client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + assert 20.0 == pool_opts.connect_timeout + assert pool_opts.wait_queue_timeout is None + assert pool_opts._ssl_context is None + assert options.replica_set_name is None + assert client.read_preference == ReadPreference.PRIMARY + assert pytest.approx(client.options.server_selection_timeout, rel=1e-9) == 12 + + def test_connect_timeout(self, simple_client): + client = simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - self.assertEqual(None, pool_opts.connect_timeout) + assert pool_opts.socket_timeout is None + assert pool_opts.connect_timeout is None - client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + client = simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - self.assertEqual(None, pool_opts.connect_timeout) + assert pool_opts.socket_timeout is None + assert pool_opts.connect_timeout is None - client = self.simple_client( + client = simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - self.assertEqual(None, pool_opts.connect_timeout) + assert pool_opts.socket_timeout is None + assert pool_opts.connect_timeout is None def test_types(self): - self.assertRaises(TypeError, MongoClient, 1) - self.assertRaises(TypeError, MongoClient, 1.14) - self.assertRaises(TypeError, MongoClient, "localhost", "27017") - self.assertRaises(TypeError, MongoClient, "localhost", 1.14) - self.assertRaises(TypeError, MongoClient, "localhost", []) - - self.assertRaises(ConfigurationError, MongoClient, []) - - def test_max_pool_size_zero(self): - self.simple_client(maxPoolSize=0) + with pytest.raises(TypeError): + MongoClient(1) # type: ignore[arg-type] + with pytest.raises(TypeError): + MongoClient(1.14) # type: ignore[arg-type] + with pytest.raises(TypeError): + MongoClient("localhost", "27017") # type: ignore[arg-type] + with pytest.raises(TypeError): + MongoClient("localhost", 1.14) # type: ignore[arg-type] + with pytest.raises(TypeError): + MongoClient("localhost", []) # type: ignore[arg-type] + + with pytest.raises(ConfigurationError): + MongoClient([]) + + def test_max_pool_size_zero(self, simple_client): + simple_client(maxPoolSize=0) def test_uri_detection(self): - self.assertRaises(ConfigurationError, MongoClient, "/foo") - self.assertRaises(ConfigurationError, MongoClient, "://") - self.assertRaises(ConfigurationError, MongoClient, "foo/") - - def test_get_db(self): + with pytest.raises(ConfigurationError): + MongoClient("/foo") + with pytest.raises(ConfigurationError): + MongoClient("://") + with pytest.raises(ConfigurationError): + MongoClient("foo/") + + def test_get_db(self, client): def make_db(base, name): return base[name] - self.assertRaises(InvalidName, make_db, self.client, "") - self.assertRaises(InvalidName, make_db, self.client, "te$t") - self.assertRaises(InvalidName, make_db, self.client, "te.t") - self.assertRaises(InvalidName, make_db, self.client, "te\\t") - self.assertRaises(InvalidName, make_db, self.client, "te/t") - self.assertRaises(InvalidName, make_db, self.client, "te st") - - self.assertTrue(isinstance(self.client.test, Database)) - self.assertEqual(self.client.test, self.client["test"]) - self.assertEqual(self.client.test, Database(self.client, "test")) - - def test_get_database(self): + with pytest.raises(InvalidName): + make_db(client, "") + with pytest.raises(InvalidName): + make_db(client, "te$t") + with pytest.raises(InvalidName): + make_db(client, "te.t") + with pytest.raises(InvalidName): + make_db(client, "te\\t") + with pytest.raises(InvalidName): + make_db(client, "te/t") + with pytest.raises(InvalidName): + make_db(client, "te st") + # Type and equality assertions + assert isinstance(client.test, Database) + assert client.test == client["test"] + assert client.test == Database(client, "test") + + def test_get_database(self, client): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) - db = self.client.get_database("foo", codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + db = client.get_database("foo", codec_options, ReadPreference.SECONDARY, write_concern) + assert db.name == "foo" + assert db.codec_options == codec_options + assert db.read_preference == ReadPreference.SECONDARY + assert db.write_concern == write_concern - def test_getattr(self): - self.assertTrue(isinstance(self.client["_does_not_exist"], Database)) + def test_getattr(self, client): + assert isinstance(client["_does_not_exist"], Database) - with self.assertRaises(AttributeError) as context: - self.client._does_not_exist + with pytest.raises(AttributeError) as context: + client.client._does_not_exist # Message should be: # "AttributeError: MongoClient has no attribute '_does_not_exist'. To # access the _does_not_exist database, use client['_does_not_exist']". - self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) + assert "has no attribute '_does_not_exist'" in str(context.value) - def test_iteration(self): - client = self.client + def test_iteration(self, client): msg = "'MongoClient' object is not iterable" - # Iteration fails - with self.assertRaisesRegex(TypeError, msg): - for _ in client: # type: ignore[misc] # error: "None" not callable [misc] + + with pytest.raises(TypeError, match=msg): + for _ in client: break + # Index fails - with self.assertRaises(TypeError): + with pytest.raises(TypeError): _ = client[0] - # next fails - with self.assertRaisesRegex(TypeError, "'MongoClient' object is not iterable"): + + # 'next' function fails + with pytest.raises(TypeError, match=msg): _ = next(client) - # .next() fails - with self.assertRaisesRegex(TypeError, "'MongoClient' object is not iterable"): + + # 'next()' method fails + with pytest.raises(TypeError, match=msg): _ = client.next() - # Do not implement typing.Iterable. - self.assertNotIsInstance(client, Iterable) - def test_get_default_database(self): - c = self.rs_or_single_client( - "mongodb://%s:%d/foo" % (client_context.host, client_context.port), + # Do not implement typing.Iterable + assert not isinstance(client, Iterable) + + def test_get_default_database(self, rs_or_single_client, client_context_fixture): + c = rs_or_single_client( + "mongodb://%s:%d/foo" % (client_context_fixture.host, client_context_fixture.port), connect=False, ) - self.assertEqual(Database(c, "foo"), c.get_default_database()) + assert Database(c, "foo") == c.get_default_database() # Test that default doesn't override the URI value. - self.assertEqual(Database(c, "foo"), c.get_default_database("bar")) - + assert Database(c, "foo") == c.get_default_database("bar") codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + assert "foo" == db.name + assert codec_options == db.codec_options + assert ReadPreference.SECONDARY == db.read_preference + assert write_concern == db.write_concern - c = self.rs_or_single_client( - "mongodb://%s:%d/" % (client_context.host, client_context.port), + c = rs_or_single_client( + "mongodb://%s:%d/" % (client_context_fixture.host, client_context_fixture.port), connect=False, ) - self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) + assert Database(c, "foo") == c.get_default_database("foo") - def test_get_default_database_error(self): + def test_get_default_database_error(self, rs_or_single_client, client_context_fixture): # URI with no database. - c = self.rs_or_single_client( - "mongodb://%s:%d/" % (client_context.host, client_context.port), + c = rs_or_single_client( + "mongodb://%s:%d/" % (client_context_fixture.host, client_context_fixture.port), connect=False, ) - self.assertRaises(ConfigurationError, c.get_default_database) + with pytest.raises(ConfigurationError): + c.get_default_database() - def test_get_default_database_with_authsource(self): + def test_get_default_database_with_authsource( + self, client_context_fixture, rs_or_single_client + ): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( - client_context.host, - client_context.port, + client_context_fixture.host, + client_context_fixture.port, ) - c = self.rs_or_single_client(uri, connect=False) - self.assertEqual(Database(c, "foo"), c.get_default_database()) + c = rs_or_single_client(uri, connect=False) + assert Database(c, "foo") == c.get_default_database() - def test_get_database_default(self): - c = self.rs_or_single_client( - "mongodb://%s:%d/foo" % (client_context.host, client_context.port), + def test_get_database_default(self, client_context_fixture, rs_or_single_client): + c = rs_or_single_client( + "mongodb://%s:%d/foo" % (client_context_fixture.host, client_context_fixture.port), connect=False, ) - self.assertEqual(Database(c, "foo"), c.get_database()) + assert Database(c, "foo") == c.get_database() - def test_get_database_default_error(self): + def test_get_database_default_error(self, client_context_fixture, rs_or_single_client): # URI with no database. - c = self.rs_or_single_client( - "mongodb://%s:%d/" % (client_context.host, client_context.port), + c = rs_or_single_client( + "mongodb://%s:%d/" % (client_context_fixture.host, client_context_fixture.port), connect=False, ) - self.assertRaises(ConfigurationError, c.get_database) + with pytest.raises(ConfigurationError): + c.get_database() - def test_get_database_default_with_authsource(self): + def test_get_database_default_with_authsource( + self, client_context_fixture, rs_or_single_client + ): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( - client_context.host, - client_context.port, + client_context_fixture.host, + client_context_fixture.port, ) - c = self.rs_or_single_client(uri, connect=False) - self.assertEqual(Database(c, "foo"), c.get_database()) + c = rs_or_single_client(uri, connect=False) + assert Database(c, "foo") == c.get_database() - def test_primary_read_pref_with_tags(self): + def test_primary_read_pref_with_tags(self, single_client): # No tags allowed with "primary". - with self.assertRaises(ConfigurationError): - self.single_client("mongodb://host/?readpreferencetags=dc:east") - - with self.assertRaises(ConfigurationError): - self.single_client("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + with pytest.raises(ConfigurationError): + with single_client("mongodb://host/?readpreferencetags=dc:east"): + pass + with pytest.raises(ConfigurationError): + with single_client("mongodb://host/?readpreference=primary&readpreferencetags=dc:east"): + pass - def test_read_preference(self): - c = self.rs_or_single_client( + def test_read_preference(self, client_context_fixture, rs_or_single_client): + c = rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) - self.assertEqual(c.read_preference, ReadPreference.NEAREST) + assert c.read_preference == ReadPreference.NEAREST - def test_metadata(self): + def test_metadata(self, simple_client): metadata = copy.deepcopy(_METADATA) if has_c(): metadata["driver"]["name"] = "PyMongo|c" else: metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} - client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") + + client = simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options - self.assertEqual(options.pool_options.metadata, metadata) - client = self.simple_client("foo", 27017, appname="foobar", connect=False) + assert options.pool_options.metadata == metadata + + client = simple_client("foo", 27017, appname="foobar", connect=False) options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata + # No error - self.simple_client(appname="x" * 128) - with self.assertRaises(ValueError): - self.simple_client(appname="x" * 129) - # Bad "driver" options. - self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") - self.assertRaises(TypeError, DriverInfo, version="1", platform="a") - self.assertRaises(TypeError, DriverInfo) - with self.assertRaises(TypeError): - self.simple_client(driver=1) - with self.assertRaises(TypeError): - self.simple_client(driver="abc") - with self.assertRaises(TypeError): - self.simple_client(driver=("Foo", "1", "a")) - # Test appending to driver info. + simple_client(appname="x" * 128) + with pytest.raises(ValueError): + simple_client(appname="x" * 129) + + # Bad "driver" options. + with pytest.raises(TypeError): + DriverInfo("Foo", 1, "a") # type: ignore[arg-type] + with pytest.raises(TypeError): + DriverInfo(version="1", platform="a") # type: ignore[call-arg] + with pytest.raises(TypeError): + DriverInfo() # type: ignore[call-arg] + with pytest.raises(TypeError): + simple_client(driver=1) + with pytest.raises(TypeError): + simple_client(driver="abc") + with pytest.raises(TypeError): + simple_client(driver=("Foo", "1", "a")) + + # Test appending to driver info. if has_c(): metadata["driver"]["name"] = "PyMongo|c|FooDriver" else: metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = self.simple_client( + + client = simple_client( "foo", 27017, appname="foobar", @@ -370,9 +395,10 @@ def test_metadata(self): connect=False, ) options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata + metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = self.simple_client( + client = simple_client( "foo", 27017, appname="foobar", @@ -380,38 +406,35 @@ def test_metadata(self): connect=False, ) options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata + # Test truncating driver info metadata. - client = self.simple_client( + client = simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) - client = self.simple_client( + assert len(bson.encode(options.pool_options.metadata)) <= _MAX_METADATA_SIZE + + client = simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) + assert len(bson.encode(options.pool_options.metadata)) <= _MAX_METADATA_SIZE - @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) - def test_container_metadata(self): - metadata = copy.deepcopy(_METADATA) - metadata["driver"]["name"] = "PyMongo" - metadata["env"] = {} - metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") - options = client.options - self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) + def test_container_metadata(self, simple_client): + with mock.patch("os.environ", {ENV_VAR_K8S: "1"}): + metadata = copy.deepcopy(_METADATA) + metadata["driver"]["name"] = "PyMongo" + metadata["env"] = {} + metadata["env"]["container"] = {"orchestrator": "kubernetes"} + + client = simple_client("mongodb://foo:27017/?appname=foobar&connect=false") + options = client.options + assert options.pool_options.metadata["env"] == metadata["env"] - def test_kwargs_codec_options(self): + def test_kwargs_codec_options(self, simple_client): class MyFloatType: def __init__(self, x): self.__x = x @@ -433,7 +456,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = self.simple_client( + c = simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -442,18 +465,16 @@ def transform_python(self, value): tzinfo=tzinfo, connect=False, ) - self.assertEqual(c.codec_options.document_class, document_class) - self.assertEqual(c.codec_options.type_registry, type_registry) - self.assertEqual(c.codec_options.tz_aware, tz_aware) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], + assert c.codec_options.document_class == document_class + assert c.codec_options.type_registry == type_registry + assert c.codec_options.tz_aware == tz_aware + assert ( + c.codec_options.uuid_representation == _UUID_REPRESENTATIONS[uuid_representation_label] ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual(c.codec_options.tzinfo, tzinfo) + assert c.codec_options.unicode_decode_error_handler == unicode_decode_error_handler + assert c.codec_options.tzinfo == tzinfo - def test_uri_codec_options(self): - # Ensure codec options are passed in correctly + def test_uri_codec_options(self, client_context_fixture, simple_client): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" datetime_conversion = "DATETIME_CLAMP" @@ -462,57 +483,36 @@ def test_uri_codec_options(self): "%s&unicode_decode_error_handler=%s" "&datetime_conversion=%s" % ( - client_context.host, - client_context.port, + client_context_fixture.host, + client_context_fixture.port, uuid_representation_label, unicode_decode_error_handler, datetime_conversion, ) ) - c = self.simple_client(uri, connect=False) - self.assertEqual(c.codec_options.tz_aware, True) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], + c = simple_client(uri, connect=False) + assert c.codec_options.tz_aware is True + assert ( + c.codec_options.uuid_representation == _UUID_REPRESENTATIONS[uuid_representation_label] ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) - + assert c.codec_options.unicode_decode_error_handler == unicode_decode_error_handler + assert c.codec_options.datetime_conversion == DatetimeConversion[datetime_conversion] # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = self.simple_client(uri, connect=False) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + c = simple_client(uri, connect=False) + assert c.codec_options.datetime_conversion == DatetimeConversion[datetime_conversion] - def test_uri_option_precedence(self): + def test_uri_option_precedence(self, simple_client): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = self.simple_client( - uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" - ) + c = simple_client(uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred") clopts = c.options opts = clopts._options + assert opts["tls"] is False + assert clopts.replica_set_name == "newname" + assert clopts.read_preference == ReadPreference.SECONDARY_PREFERRED - self.assertEqual(opts["tls"], False) - self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) - - def test_connection_timeout_ms_propagates_to_DNS_resolver(self): - # Patch the resolver. - from pymongo.srv_resolver import _resolve - - patched_resolver = FunctionCallRecorder(_resolve) - pymongo.srv_resolver._resolve = patched_resolver - - def reset_resolver(): - pymongo.srv_resolver._resolve = _resolve - - self.addCleanup(reset_resolver) - - # Setup. + def test_connection_timeout_ms_propagates_to_DNS_resolver(self, patch_resolver, simple_client): base_uri = "mongodb+srv://test5.test.build.10gen.cc" connectTimeoutMS = 5000 expected_kw_value = 5.0 @@ -520,10 +520,10 @@ def reset_resolver(): expected_uri_value = 6.0 def test_scenario(args, kwargs, expected_value): - patched_resolver.reset() - self.simple_client(*args, **kwargs) - for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw["lifetime"], expected_value) + patch_resolver.reset() + simple_client(*args, **kwargs) + for _, kw in patch_resolver.call_list(): + assert pytest.approx(kw["lifetime"], rel=1e-6) == expected_value # No timeout specified. test_scenario((base_uri,), {}, CONNECT_TIMEOUT) @@ -538,38 +538,38 @@ def test_scenario(args, kwargs, expected_value): # Timeout specified in both kwargs and connection string. test_scenario((uri_with_timeout,), kwarg, expected_kw_value) - def test_uri_security_options(self): + def test_uri_security_options(self, simple_client): # Ensure that we don't silently override security-related options. - with self.assertRaises(InvalidURI): - self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) + with pytest.raises(InvalidURI): + simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) - self.assertEqual(c.options._options["tls"], False) + c = simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) + assert c.options._options["tls"] is False # Conflicting tlsInsecure options should raise an error. - with self.assertRaises(InvalidURI): - self.simple_client( + with pytest.raises(InvalidURI): + simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, ) # Conflicting legacy tlsInsecure options should also raise an error. - with self.assertRaises(InvalidURI): - self.simple_client( + with pytest.raises(InvalidURI): + simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, ) # Conflicting kwargs should raise InvalidURI - with self.assertRaises(InvalidURI): - self.simple_client(ssl=True, tls=False) + with pytest.raises(InvalidURI): + simple_client(ssl=True, tls=False) - def test_event_listeners(self): - c = self.simple_client(event_listeners=[], connect=False) - self.assertEqual(c.options.event_listeners, []) + def test_event_listeners(self, simple_client): + c = simple_client(event_listeners=[], connect=False) + assert c.options.event_listeners == [] listeners = [ event_loggers.CommandLogger(), event_loggers.HeartbeatLogger(), @@ -577,28 +577,30 @@ def test_event_listeners(self): event_loggers.TopologyLogger(), event_loggers.ConnectionPoolLogger(), ] - c = self.simple_client(event_listeners=listeners, connect=False) - self.assertEqual(c.options.event_listeners, listeners) - - def test_client_options(self): - c = self.simple_client(connect=False) - self.assertIsInstance(c.options, ClientOptions) - self.assertIsInstance(c.options.pool_options, PoolOptions) - self.assertEqual(c.options.server_selection_timeout, 30) - self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) - self.assertIsInstance(c.options.retry_writes, bool) - self.assertIsInstance(c.options.retry_reads, bool) + c = simple_client(event_listeners=listeners, connect=False) + assert c.options.event_listeners == listeners + + def test_client_options(self, simple_client): + c = simple_client(connect=False) + assert isinstance(c.options, ClientOptions) + assert isinstance(c.options.pool_options, PoolOptions) + assert c.options.server_selection_timeout == 30 + assert c.options.pool_options.max_idle_time_seconds is None + assert isinstance(c.options.retry_writes, bool) + assert isinstance(c.options.retry_reads, bool) def test_validate_suggestion(self): """Validate kwargs in constructor.""" for typo in ["auth", "Auth", "AUTH"]: - expected = f"Unknown option: {typo}. Did you mean one of (authsource, authmechanism, authoidcallowedhosts) or maybe a camelCase version of one? Refer to docstring." + expected = ( + f"Unknown option: {typo}. Did you mean one of (authsource, authmechanism, " + f"authoidcallowedhosts) or maybe a camelCase version of one? Refer to docstring." + ) expected = re.escape(expected) - with self.assertRaisesRegex(ConfigurationError, expected): + with pytest.raises(ConfigurationError, match=expected): MongoClient(**{typo: "standard"}) # type: ignore[arg-type] - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_logging(self, mock_get_hosts): + def test_detected_environment_logging(self, caplog): normal_hosts = [ "normal.host.com", "host.cosmos.azure.com", @@ -609,42 +611,47 @@ def test_detected_environment_logging(self, mock_get_hosts): multi_host = ( "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" ) - with self.assertLogs("pymongo", level="INFO") as cm: - for host in normal_hosts: - MongoClient(host, connect=False) - for host in srv_hosts: - mock_get_hosts.return_value = [(host, 1)] - MongoClient(host, connect=False) - MongoClient(multi_host, connect=False) - logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] - self.assertEqual(len(logs), 7) - - @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_warning(self, mock_get_hosts): - with self._caplog.at_level(logging.WARN): - normal_hosts = [ - "host.cosmos.azure.com", - "host.docdb.amazonaws.com", - "host.docdb-elastic.amazonaws.com", - ] - srv_hosts = ["mongodb+srv://:@" + s for s in normal_hosts] - multi_host = ( - "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" - ) - for host in normal_hosts: - with self.assertWarns(UserWarning): - self.simple_client(host) - for host in srv_hosts: - mock_get_hosts.return_value = [(host, 1)] - with self.assertWarns(UserWarning): - self.simple_client(host) - with self.assertWarns(UserWarning): - self.simple_client(multi_host) - - -class TestClient(IntegrationTest): + with caplog.at_level(logging.INFO, logger="pymongo"): + with mock.patch("pymongo.srv_resolver._SrvResolver.get_hosts") as mock_get_hosts: + for host in normal_hosts: + MongoClient(host, connect=False) + for host in srv_hosts: + mock_get_hosts.return_value = [(host, 1)] + MongoClient(host, connect=False) + MongoClient(multi_host, connect=False) + logs = [ + record.getMessage() + for record in caplog.records + if record.name == "pymongo.client" + ] + assert len(logs) == 7 + + def test_detected_environment_warning(self, caplog, simple_client): + normal_hosts = [ + "host.cosmos.azure.com", + "host.docdb.amazonaws.com", + "host.docdb-elastic.amazonaws.com", + ] + srv_hosts = ["mongodb+srv://:@" + s for s in normal_hosts] + multi_host = ( + "host.cosmos.azure.com,host.docdb.amazonaws.com,host.docdb-elastic.amazonaws.com" + ) + with caplog.at_level(logging.WARN, logger="pymongo"): + with mock.patch("pymongo.srv_resolver._SrvResolver.get_hosts") as mock_get_hosts: + with pytest.warns(UserWarning): + for host in normal_hosts: + simple_client(host) + for host in srv_hosts: + mock_get_hosts.return_value = [(host, 1)] + simple_client(host) + simple_client(multi_host) + + +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestClientIntegrationTest(PyMongoTestCasePyTest): def test_multiple_uris(self): - with self.assertRaises(ConfigurationError): + with pytest.raises(ConfigurationError): MongoClient( host=[ "mongodb+srv://cluster-a.abc12.mongodb.net", @@ -653,207 +660,208 @@ def test_multiple_uris(self): ] ) - def test_max_idle_time_reaper_default(self): + def test_max_idle_time_reaper_default(self, rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = self.rs_or_single_client() + client = rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.conns)) - self.assertTrue(conn in server._pool.conns) + assert 1 == len(server._pool.conns) + assert conn in server._pool.conns - def test_max_idle_time_reaper_removes_stale_minPoolSize(self): + def test_max_idle_time_reaper_removes_stale_minPoolSize(self, rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass # When the reaper runs at the same time as the get_socket, two # connections could be created and checked into the pool. - self.assertGreaterEqual(len(server._pool.conns), 1) + assert len(server._pool.conns) >= 1 wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): + def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self, rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) + client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two connections from being created. - self.assertEqual(1, len(server._pool.conns)) + assert 1 == len(server._pool.conns) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - def test_max_idle_time_reaper_removes_stale(self): + def test_max_idle_time_reaper_removes_stale(self, rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): - # Assert reaper has removed idle socket and NOT replaced it - client = self.rs_or_single_client(maxIdleTimeMS=500) + # Assert that the reaper has removed the idle socket and NOT replaced it. + client = rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn_one: pass - # Assert that the pool does not close connections prematurely. + # Assert that the pool does not close connections prematurely time.sleep(0.300) with server._pool.checkout() as conn_two: pass - self.assertIs(conn_one, conn_two) + assert conn_one is conn_two wait_until( lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - def test_min_pool_size(self): + def test_min_pool_size(self, rs_or_single_client): with client_knobs(kill_cursor_frequency=0.1): - client = self.rs_or_single_client() + client = rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) - self.assertEqual(0, len(server._pool.conns)) + assert len(server._pool.conns) == 0 # Assert that pool started up at minPoolSize - client = self.rs_or_single_client(minPoolSize=10) + client = rs_or_single_client(minPoolSize=10) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) wait_until( lambda: len(server._pool.conns) == 10, "pool initialized with 10 connections", ) - - # Assert that if a socket is closed, a new one takes its place + # Assert that if a socket is closed, a new one takes its place. with server._pool.checkout() as conn: conn.close_conn(None) wait_until( lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", ) - self.assertFalse(conn in server._pool.conns) + assert conn not in server._pool.conns - def test_max_idle_time_checkout(self): + def test_max_idle_time_checkout(self, rs_or_single_client): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = self.rs_or_single_client(maxIdleTimeMS=500) + client = rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.conns)) + assert len(server._pool.conns) == 1 time.sleep(1) # Sleep so that the socket becomes stale. - with server._pool.checkout() as new_con: - self.assertNotEqual(conn, new_con) - self.assertEqual(1, len(server._pool.conns)) - self.assertFalse(conn in server._pool.conns) - self.assertTrue(new_con in server._pool.conns) + with server._pool.checkout() as new_conn: + assert conn != new_conn + assert len(server._pool.conns) == 1 + assert conn not in server._pool.conns + assert new_conn in server._pool.conns # Test that connections are reused if maxIdleTimeMS is not set. - client = self.rs_or_single_client() + client = rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.conns)) + assert len(server._pool.conns) == 1 time.sleep(1) - with server._pool.checkout() as new_con: - self.assertEqual(conn, new_con) - self.assertEqual(1, len(server._pool.conns)) + with server._pool.checkout() as new_conn: + assert conn == new_conn + assert len(server._pool.conns) == 1 - def test_constants(self): + def test_constants(self, client_context_fixture, simple_client): """This test uses MongoClient explicitly to make sure that host and port are not overloaded. """ - host, port = client_context.host, client_context.port - kwargs: dict = client_context.default_client_options.copy() - if client_context.auth_enabled: + host, port = ( + client_context_fixture.host, + client_context_fixture.port, + ) + kwargs: dict = client_context_fixture.default_client_options.copy() + if client_context_fixture.auth_enabled: kwargs["username"] = db_user kwargs["password"] = db_pwd # Set bad defaults. MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 - with self.assertRaises(AutoReconnect): - c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) + with pytest.raises(AutoReconnect): + c = simple_client(serverSelectionTimeoutMS=10, **kwargs) connected(c) - - c = self.simple_client(host, port, **kwargs) + c = simple_client(host, port, **kwargs) # Override the defaults. No error. connected(c) - # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port - # No error. - c = self.simple_client(**kwargs) + c = simple_client(**kwargs) connected(c) - def test_init_disconnected(self): - host, port = client_context.host, client_context.port - c = self.rs_or_single_client(connect=False) + def test_init_disconnected(self, client_context_fixture, rs_or_single_client, simple_client): + host, port = ( + client_context_fixture.host, + client_context_fixture.port, + ) + c = rs_or_single_client(connect=False) # is_primary causes client to block until connected - self.assertIsInstance(c.is_primary, bool) - c = self.rs_or_single_client(connect=False) - self.assertIsInstance(c.is_mongos, bool) - c = self.rs_or_single_client(connect=False) - self.assertIsInstance(c.options.pool_options.max_pool_size, int) - self.assertIsInstance(c.nodes, frozenset) - - c = self.rs_or_single_client(connect=False) - self.assertEqual(c.codec_options, CodecOptions()) - c = self.rs_or_single_client(connect=False) - self.assertFalse(c.primary) - self.assertFalse(c.secondaries) - c = self.rs_or_single_client(connect=False) - self.assertIsInstance(c.topology_description, TopologyDescription) - self.assertEqual(c.topology_description, c._topology._description) - if client_context.is_rs: + assert isinstance(c.is_primary, bool) + c = rs_or_single_client(connect=False) + assert isinstance(c.is_mongos, bool) + c = rs_or_single_client(connect=False) + assert isinstance(c.options.pool_options.max_pool_size, int) + assert isinstance(c.nodes, frozenset) + + c = rs_or_single_client(connect=False) + assert c.codec_options == CodecOptions() + c = rs_or_single_client(connect=False) + assert not c.primary + assert not c.secondaries + c = rs_or_single_client(connect=False) + assert isinstance(c.topology_description, TopologyDescription) + assert c.topology_description == c._topology._description + if client_context_fixture.is_rs: # The primary's host and port are from the replica set config. - self.assertIsNotNone(c.address) + assert c.address is not None else: - self.assertEqual(c.address, (host, port)) - + assert c.address == (host, port) bad_host = "somedomainthatdoesntexist.org" - c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - with self.assertRaises(ConnectionFailure): + c = simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + with pytest.raises(ConnectionFailure): c.pymongo_test.test.find_one() - def test_init_disconnected_with_auth(self): + def test_init_disconnected_with_auth(self, simple_client): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - with self.assertRaises(ConnectionFailure): + c = simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + with pytest.raises(ConnectionFailure): c.pymongo_test.test.find_one() - def test_equality(self): - seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = self.rs_or_single_client(seed, connect=False) - self.assertEqual(client_context.client, c) + def test_equality(self, client_context_fixture, rs_or_single_client, simple_client): + seed = "{}:{}".format(*list(client_context_fixture.client._topology_settings.seeds)[0]) + c = rs_or_single_client(seed, connect=False) + assert client_context_fixture.client == c # Explicitly test inequality - self.assertFalse(client_context.client != c) + assert not client_context_fixture.client != c - c = self.rs_or_single_client("invalid.com", connect=False) - self.assertNotEqual(client_context.client, c) - self.assertTrue(client_context.client != c) + c = rs_or_single_client("invalid.com", connect=False) + assert client_context_fixture.client != c + assert client_context_fixture.client != c - c1 = self.simple_client("a", connect=False) - c2 = self.simple_client("b", connect=False) + c1 = simple_client("a", connect=False) + c2 = simple_client("b", connect=False) # Seeds differ: - self.assertNotEqual(c1, c2) + assert c1 != c2 - c1 = self.simple_client(["a", "b", "c"], connect=False) - c2 = self.simple_client(["c", "a", "b"], connect=False) + c1 = simple_client(["a", "b", "c"], connect=False) + c2 = simple_client(["c", "a", "b"], connect=False) # Same seeds but out of order still compares equal: - self.assertEqual(c1, c2) - - def test_hashable(self): - seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = self.rs_or_single_client(seed, connect=False) - self.assertIn(c, {client_context.client}) - c = self.rs_or_single_client("invalid.com", connect=False) - self.assertNotIn(c, {client_context.client}) - - def test_host_w_port(self): - with self.assertRaises(ValueError): - host = client_context.host + assert c1 == c2 + + def test_hashable(self, client_context_fixture, rs_or_single_client): + seed = "{}:{}".format(*list(client_context_fixture.client._topology_settings.seeds)[0]) + c = rs_or_single_client(seed, connect=False) + assert c in {client_context_fixture.client} + c = rs_or_single_client("invalid.com", connect=False) + assert c not in {client_context_fixture.client} + + def test_host_w_port(self, client_context_fixture): + with pytest.raises(ValueError): + host = client_context_fixture.host connected( MongoClient( f"{host}:1234567", @@ -862,7 +870,7 @@ def test_host_w_port(self): ) ) - def test_repr(self): + def test_repr(self, simple_client): # Used to test 'eval' below. import bson @@ -872,19 +880,16 @@ def test_repr(self): connect=False, document_class=SON, ) - the_repr = repr(client) - self.assertIn("MongoClient(host=", the_repr) - self.assertIn("document_class=bson.son.SON, tz_aware=False, connect=False, ", the_repr) - self.assertIn("connecttimeoutms=12345", the_repr) - self.assertIn("replicaset='replset'", the_repr) - self.assertIn("w=1", the_repr) - self.assertIn("wtimeoutms=100", the_repr) - + assert "MongoClient(host=" in the_repr + assert "document_class=bson.son.SON, tz_aware=False, connect=False, " in the_repr + assert "connecttimeoutms=12345" in the_repr + assert "replicaset='replset'" in the_repr + assert "w=1" in the_repr + assert "wtimeoutms=100" in the_repr with eval(the_repr) as client_two: - self.assertEqual(client_two, client) - - client = self.simple_client( + assert client_two == client + client = simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -894,91 +899,94 @@ def test_repr(self): connect=False, ) the_repr = repr(client) - self.assertIn("MongoClient(host=", the_repr) - self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) - self.assertIn("connecttimeoutms=12345", the_repr) - self.assertIn("replicaset='replset'", the_repr) - self.assertIn("sockettimeoutms=None", the_repr) - self.assertIn("w=1", the_repr) - self.assertIn("wtimeoutms=100", the_repr) - + assert "MongoClient(host=" in the_repr + assert "document_class=dict, tz_aware=False, connect=False, " in the_repr + assert "connecttimeoutms=12345" in the_repr + assert "replicaset='replset'" in the_repr + assert "sockettimeoutms=None" in the_repr + assert "w=1" in the_repr + assert "wtimeoutms=100" in the_repr with eval(the_repr) as client_two: - self.assertEqual(client_two, client) + assert client_two == client - def test_getters(self): - wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") + def test_getters(self, client_context_fixture): + wait_until( + lambda: client_context_fixture.nodes == client_context_fixture.client.nodes, + "find all nodes", + ) - def test_list_databases(self): - cmd_docs = (self.client.admin.command("listDatabases"))["databases"] - cursor = self.client.list_databases() - self.assertIsInstance(cursor, CommandCursor) + def test_list_databases(self, client_context_fixture, rs_or_single_client): + cmd_docs = (client_context_fixture.client.admin.command("listDatabases"))["databases"] + cursor = client_context_fixture.client.list_databases() + assert isinstance(cursor, CommandCursor) helper_docs = cursor.to_list() - self.assertTrue(len(helper_docs) > 0) - self.assertEqual(len(helper_docs), len(cmd_docs)) + assert len(helper_docs) > 0 + assert len(helper_docs) == len(cmd_docs) # PYTHON-3529 Some fields may change between calls, just compare names. for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): - self.assertIs(type(helper_doc), dict) - self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = self.rs_or_single_client(document_class=SON) - for doc in client.list_databases(): - self.assertIs(type(doc), dict) - - self.client.pymongo_test.test.insert_one({}) - cursor = self.client.list_databases(filter={"name": "admin"}) + assert isinstance(helper_doc, dict) + assert helper_doc.keys() == cmd_doc.keys() + + client_doc = rs_or_single_client(document_class=SON) + for doc in client_doc.list_databases(): + assert isinstance(doc, dict) + + client_context_fixture.client.pymongo_test.test.insert_one({}) + cursor = client_context_fixture.client.list_databases(filter={"name": "admin"}) docs = cursor.to_list() - self.assertEqual(1, len(docs)) - self.assertEqual(docs[0]["name"], "admin") + assert len(docs) == 1 + assert docs[0]["name"] == "admin" - cursor = self.client.list_databases(nameOnly=True) + cursor = client_context_fixture.client.list_databases(nameOnly=True) for doc in cursor: - self.assertEqual(["name"], list(doc)) + assert list(doc) == ["name"] - def test_list_database_names(self): - self.client.pymongo_test.test.insert_one({"dummy": "object"}) - self.client.pymongo_test_mike.test.insert_one({"dummy": "object"}) - cmd_docs = (self.client.admin.command("listDatabases"))["databases"] + def test_list_database_names(self, client_context_fixture): + client_context_fixture.client.pymongo_test.test.insert_one({"dummy": "object"}) + client_context_fixture.client.pymongo_test_mike.test.insert_one({"dummy": "object"}) + cmd_docs = (client_context_fixture.client.admin.command("listDatabases"))["databases"] cmd_names = [doc["name"] for doc in cmd_docs] - db_names = self.client.list_database_names() - self.assertTrue("pymongo_test" in db_names) - self.assertTrue("pymongo_test_mike" in db_names) - self.assertEqual(db_names, cmd_names) - - def test_drop_database(self): - with self.assertRaises(TypeError): - self.client.drop_database(5) # type: ignore[arg-type] - with self.assertRaises(TypeError): - self.client.drop_database(None) # type: ignore[arg-type] - - self.client.pymongo_test.test.insert_one({"dummy": "object"}) - self.client.pymongo_test2.test.insert_one({"dummy": "object"}) - dbs = self.client.list_database_names() - self.assertIn("pymongo_test", dbs) - self.assertIn("pymongo_test2", dbs) - self.client.drop_database("pymongo_test") - - if client_context.is_rs: - wc_client = self.rs_or_single_client(w=len(client_context.nodes) + 1) - with self.assertRaises(WriteConcernError): + db_names = client_context_fixture.client.list_database_names() + assert "pymongo_test" in db_names + assert "pymongo_test_mike" in db_names + assert db_names == cmd_names + + def test_drop_database(self, client_context_fixture, rs_or_single_client): + with pytest.raises(TypeError): + client_context_fixture.client.drop_database(5) # type: ignore[arg-type] + with pytest.raises(TypeError): + client_context_fixture.client.drop_database(None) # type: ignore[arg-type] + + client_context_fixture.client.pymongo_test.test.insert_one({"dummy": "object"}) + client_context_fixture.client.pymongo_test2.test.insert_one({"dummy": "object"}) + dbs = client_context_fixture.client.list_database_names() + assert "pymongo_test" in dbs + assert "pymongo_test2" in dbs + client_context_fixture.client.drop_database("pymongo_test") + + if client_context_fixture.is_rs: + wc_client = rs_or_single_client(w=len(client_context_fixture.nodes) + 1) + with pytest.raises(WriteConcernError): wc_client.drop_database("pymongo_test2") - self.client.drop_database(self.client.pymongo_test2) - dbs = self.client.list_database_names() - self.assertNotIn("pymongo_test", dbs) - self.assertNotIn("pymongo_test2", dbs) + client_context_fixture.client.drop_database(client_context_fixture.client.pymongo_test2) + dbs = client_context_fixture.client.list_database_names() + assert "pymongo_test" not in dbs + assert "pymongo_test2" not in dbs - def test_close(self): - test_client = self.rs_or_single_client() + def test_close(self, rs_or_single_client): + test_client = rs_or_single_client() coll = test_client.pymongo_test.bar test_client.close() - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): coll.count_documents({}) - def test_close_kills_cursors(self): + def test_close_kills_cursors(self, rs_or_single_client): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = self.rs_or_single_client() + test_client = rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() test_client._process_periodic_tasks() @@ -990,238 +998,250 @@ def test_close_kills_cursors(self): # Open a cursor and leave it open on the server. cursor = coll.find().batch_size(10) - self.assertTrue(bool(next(cursor))) - self.assertLess(cursor.retrieved, docs_inserted) + assert bool(next(cursor)) + assert cursor.retrieved < docs_inserted # Open a command cursor and leave it open on the server. cursor = coll.aggregate([], batchSize=10) - self.assertTrue(bool(next(cursor))) + assert bool(next(cursor)) del cursor # Required for PyPy, Jython and other Python implementations that # don't use reference counting garbage collection. gc.collect() # Close the client and ensure the topology is closed. - self.assertTrue(test_client._topology._opened) + assert test_client._topology._opened test_client.close() - self.assertFalse(test_client._topology._opened) - test_client = self.rs_or_single_client() + assert not test_client._topology._opened + test_client = rs_or_single_client() # The killCursors task should not need to re-open the topology. test_client._process_periodic_tasks() - self.assertTrue(test_client._topology._opened) + assert test_client._topology._opened - def test_close_stops_kill_cursors_thread(self): - client = self.rs_client() + def test_close_stops_kill_cursors_thread(self, rs_client): + client = rs_client() client.test.test.find_one() - self.assertFalse(client._kill_cursors_executor._stopped) + assert not client._kill_cursors_executor._stopped # Closing the client should stop the thread. client.close() - self.assertTrue(client._kill_cursors_executor._stopped) + assert client._kill_cursors_executor._stopped # Reusing the closed client should raise an InvalidOperation error. - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): client.admin.command("ping") # Thread is still stopped. - self.assertTrue(client._kill_cursors_executor._stopped) + assert client._kill_cursors_executor._stopped - def test_uri_connect_option(self): + def test_uri_connect_option(self, rs_client): # Ensure that topology is not opened if connect=False. - client = self.rs_client(connect=False) - self.assertFalse(client._topology._opened) + client = rs_client(connect=False) + assert not client._topology._opened # Ensure kill cursors thread has not been started. if _IS_SYNC: kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) + assert not (kc_thread and kc_thread.is_alive()) else: kc_task = client._kill_cursors_executor._task - self.assertFalse(kc_task and not kc_task.done()) + assert not (kc_task and not kc_task.done()) # Using the client should open topology and start the thread. client.admin.command("ping") - self.assertTrue(client._topology._opened) + assert client._topology._opened if _IS_SYNC: kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + assert kc_thread and kc_thread.is_alive() else: kc_task = client._kill_cursors_executor._task - self.assertTrue(kc_task and not kc_task.done()) + assert kc_task and not kc_task.done() - def test_close_does_not_open_servers(self): - client = self.rs_client(connect=False) + def test_close_does_not_open_servers(self, rs_client): + client = rs_client(connect=False) topology = client._topology - self.assertEqual(topology._servers, {}) + assert topology._servers == {} client.close() - self.assertEqual(topology._servers, {}) + assert topology._servers == {} - def test_close_closes_sockets(self): - client = self.rs_client() + def test_close_closes_sockets(self, rs_client): + client = rs_client() client.test.test.find_one() topology = client._topology client.close() for server in topology._servers.values(): - self.assertFalse(server._pool.conns) - self.assertTrue(server._monitor._executor._stopped) - self.assertTrue(server._monitor._rtt_monitor._executor._stopped) - self.assertFalse(server._monitor._pool.conns) - self.assertFalse(server._monitor._rtt_monitor._pool.conns) + assert not server._pool.conns + assert server._monitor._executor._stopped + assert server._monitor._rtt_monitor._executor._stopped + assert not server._monitor._pool.conns + assert not server._monitor._rtt_monitor._pool.conns def test_bad_uri(self): - with self.assertRaises(InvalidURI): + with pytest.raises(InvalidURI): MongoClient("http://localhost") - @client_context.require_auth - @client_context.require_no_fips - def test_auth_from_uri(self): - host, port = client_context.host, client_context.port - client_context.create_user("admin", "admin", "pass") - self.addCleanup(client_context.drop_user, "admin", "admin") - self.addCleanup(remove_all_users, self.client.pymongo_test) + @pytest.mark.usefixtures("require_auth") + @pytest.mark.usefixtures("require_no_fips") + @pytest.mark.parametrize("remove_all_users_fixture", ["pymongo_test"], indirect=True) + @pytest.mark.parametrize("drop_user_fixture", [("admin", "admin")], indirect=True) + def test_auth_from_uri( + self, + client_context_fixture, + rs_or_single_client_noauth, + remove_all_users_fixture, + drop_user_fixture, + ): + host, port = ( + client_context_fixture.host, + client_context_fixture.port, + ) + client_context_fixture.create_user("admin", "admin", "pass") - client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) + client_context_fixture.create_user( + "pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"] + ) - with self.assertRaises(OperationFailure): - connected(self.rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) + with pytest.raises(OperationFailure): + connected(rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) # No error. - connected(self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) + connected(rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) - with self.assertRaises(OperationFailure): - connected(self.rs_or_single_client_noauth(uri)) + with pytest.raises(OperationFailure): + connected(rs_or_single_client_noauth(uri)) # No error. connected( - self.rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) + rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) ) # Auth with lazy connection. ( - self.rs_or_single_client_noauth( + rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = self.rs_or_single_client_noauth( + bad_client = rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): bad_client.pymongo_test.test.find_one() - @client_context.require_auth - def test_username_and_password(self): - client_context.create_user("admin", "ad min", "pa/ss") - self.addCleanup(client_context.drop_user, "admin", "ad min") + @pytest.mark.usefixtures("require_auth") + @pytest.mark.parametrize("drop_user_fixture", [("admin", "ad min")], indirect=True) + def test_username_and_password( + self, client_context_fixture, rs_or_single_client_noauth, drop_user_fixture + ): + client_context_fixture.create_user("admin", "ad min", "pa/ss") - c = self.rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. - self.assertNotIn("ad min", repr(c)) - self.assertNotIn("ad min", str(c)) - self.assertNotIn("pa/ss", repr(c)) - self.assertNotIn("pa/ss", str(c)) + assert "ad min" not in repr(c) + assert "ad min" not in str(c) + assert "pa/ss" not in repr(c) + assert "pa/ss" not in str(c) # Auth succeeds. c.server_info() - with self.assertRaises(OperationFailure): - (self.rs_or_single_client_noauth(username="ad min", password="foo")).server_info() + with pytest.raises(OperationFailure): + (rs_or_single_client_noauth(username="ad min", password="foo")).server_info() - @client_context.require_auth - @client_context.require_no_fips - def test_lazy_auth_raises_operation_failure(self): - host = client_context.host - lazy_client = self.rs_or_single_client_noauth( + @pytest.mark.usefixtures("require_auth") + @pytest.mark.usefixtures("require_no_fips") + def test_lazy_auth_raises_operation_failure( + self, client_context_fixture, rs_or_single_client_noauth + ): + host = client_context_fixture.host + lazy_client = rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one) - @client_context.require_no_tls - def test_unix_socket(self): + @pytest.mark.usefixtures("require_no_tls") + def test_unix_socket(self, client_context_fixture, rs_or_single_client, simple_client): if not hasattr(socket, "AF_UNIX"): - raise SkipTest("UNIX-sockets are not supported on this system") + pytest.skip("UNIX-sockets are not supported on this system") - mongodb_socket = "/tmp/mongodb-%d.sock" % (client_context.port,) - encoded_socket = "%2Ftmp%2F" + "mongodb-%d.sock" % (client_context.port,) + mongodb_socket = "/tmp/mongodb-%d.sock" % (client_context_fixture.port,) + encoded_socket = "%2Ftmp%2F" + "mongodb-%d.sock" % (client_context_fixture.port,) if not os.access(mongodb_socket, os.R_OK): - raise SkipTest("Socket file is not accessible") + pytest.skip("Socket file is not accessible") uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = self.rs_or_single_client(uri) + client = rs_or_single_client(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() - self.assertTrue("pymongo_test" in dbs) + assert "pymongo_test" in dbs - self.assertTrue(mongodb_socket in repr(client)) + assert mongodb_socket in repr(client) # Confirm it fails with a missing socket. - with self.assertRaises(ConnectionFailure): - c = self.simple_client( - "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 - ) + with pytest.raises(ConnectionFailure): + c = simple_client("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100) connected(c) - def test_document_class(self): - c = self.client + def test_document_class(self, client_context_fixture, rs_or_single_client): + c = client_context_fixture.client db = c.pymongo_test db.test.insert_one({"x": 1}) - self.assertEqual(dict, c.codec_options.document_class) - self.assertTrue(isinstance(db.test.find_one(), dict)) - self.assertFalse(isinstance(db.test.find_one(), SON)) + assert dict == c.codec_options.document_class + assert isinstance(db.test.find_one(), dict) + assert not isinstance(db.test.find_one(), SON) - c = self.rs_or_single_client(document_class=SON) + c = rs_or_single_client(document_class=SON) db = c.pymongo_test - self.assertEqual(SON, c.codec_options.document_class) - self.assertTrue(isinstance(db.test.find_one(), SON)) + assert SON == c.codec_options.document_class + assert isinstance(db.test.find_one(), SON) - def test_timeouts(self): - client = self.rs_or_single_client( + def test_timeouts(self, rs_or_single_client): + client = rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) - self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout) - self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout) - self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds) - self.assertEqual(10.5, client.options.pool_options.max_idle_time_seconds) - self.assertEqual(10.5, client.options.server_selection_timeout) + assert 10.5 == (get_pool(client)).opts.connect_timeout + assert 10.5 == (get_pool(client)).opts.socket_timeout + assert 10.5 == (get_pool(client)).opts.max_idle_time_seconds + assert 10.5 == client.options.pool_options.max_idle_time_seconds + assert 10.5 == client.options.server_selection_timeout - def test_socket_timeout_ms_validation(self): - c = self.rs_or_single_client(socketTimeoutMS=10 * 1000) - self.assertEqual(10, (get_pool(c)).opts.socket_timeout) + def test_socket_timeout_ms_validation(self, rs_or_single_client): + c = rs_or_single_client(socketTimeoutMS=10 * 1000) + assert 10 == (get_pool(c)).opts.socket_timeout - c = connected(self.rs_or_single_client(socketTimeoutMS=None)) - self.assertEqual(None, (get_pool(c)).opts.socket_timeout) + c = connected(rs_or_single_client(socketTimeoutMS=None)) + assert (get_pool(c)).opts.socket_timeout is None - c = connected(self.rs_or_single_client(socketTimeoutMS=0)) - self.assertEqual(None, (get_pool(c)).opts.socket_timeout) + c = connected(rs_or_single_client(socketTimeoutMS=0)) + assert (get_pool(c)).opts.socket_timeout is None - with self.assertRaises(ValueError): - with self.rs_or_single_client(socketTimeoutMS=-1): + with pytest.raises(ValueError): + with rs_or_single_client(socketTimeoutMS=-1): pass - with self.assertRaises(ValueError): - with self.rs_or_single_client(socketTimeoutMS=1e10): + with pytest.raises(ValueError): + with rs_or_single_client(socketTimeoutMS=1e10): pass - with self.assertRaises(ValueError): - with self.rs_or_single_client(socketTimeoutMS="foo"): + with pytest.raises(ValueError): + with rs_or_single_client(socketTimeoutMS="foo"): pass - def test_socket_timeout(self): - no_timeout = self.client + def test_socket_timeout(self, client_context_fixture, rs_or_single_client): + no_timeout = client_context_fixture.client timeout_sec = 1 - timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addCleanup(timeout.close) + timeout = rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) no_timeout.pymongo_test.drop_collection("test") no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1233,129 +1253,125 @@ def get_x(db): doc = next(db.test.find().where(where_func)) return doc["x"] - self.assertEqual(1, get_x(no_timeout.pymongo_test)) - with self.assertRaises(NetworkTimeout): + assert 1 == get_x(no_timeout.pymongo_test) + with pytest.raises(NetworkTimeout): get_x(timeout.pymongo_test) def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) - self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0.1) client.close() client = MongoClient(serverSelectionTimeoutMS=0, connect=False) - self.assertAlmostEqual(0, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0) - self.assertRaises(ValueError, MongoClient, serverSelectionTimeoutMS="foo", connect=False) - self.assertRaises(ValueError, MongoClient, serverSelectionTimeoutMS=-1, connect=False) - self.assertRaises( - ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False - ) + pytest.raises(ValueError, MongoClient, serverSelectionTimeoutMS="foo", connect=False) + pytest.raises(ValueError, MongoClient, serverSelectionTimeoutMS=-1, connect=False) + pytest.raises(ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False) client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) - self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0.1) client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) - self.assertAlmostEqual(0, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 0) client.close() # Test invalid timeout in URI ignored and set to default. client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) - self.assertAlmostEqual(30, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 30) client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) - self.assertAlmostEqual(30, client.options.server_selection_timeout) + pytest.approx(client.options.server_selection_timeout, 30) - def test_waitQueueTimeoutMS(self): - client = self.rs_or_single_client(waitQueueTimeoutMS=2000) - self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) + def test_waitQueueTimeoutMS(self, rs_or_single_client): + client = rs_or_single_client(waitQueueTimeoutMS=2000) + assert 2 == (get_pool(client)).opts.wait_queue_timeout - def test_socketKeepAlive(self): - pool = get_pool(self.client) + def test_socketKeepAlive(self, client_context_fixture): + pool = get_pool(client_context_fixture.client) with pool.checkout() as conn: keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) - self.assertTrue(keepalive) + assert keepalive @no_type_check - def test_tz_aware(self): - self.assertRaises(ValueError, MongoClient, tz_aware="foo") + def test_tz_aware(self, client_context_fixture, rs_or_single_client): + pytest.raises(ValueError, MongoClient, tz_aware="foo") - aware = self.rs_or_single_client(tz_aware=True) - self.addCleanup(aware.close) - naive = self.client + aware = rs_or_single_client(tz_aware=True) + naive = client_context_fixture.client aware.pymongo_test.drop_collection("test") now = datetime.datetime.now(tz=datetime.timezone.utc) aware.pymongo_test.test.insert_one({"x": now}) - self.assertEqual(None, (naive.pymongo_test.test.find_one())["x"].tzinfo) - self.assertEqual(utc, (aware.pymongo_test.test.find_one())["x"].tzinfo) - self.assertEqual( - (aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None), - (naive.pymongo_test.test.find_one())["x"], - ) + assert (naive.pymongo_test.test.find_one())["x"].tzinfo is None + assert utc == (aware.pymongo_test.test.find_one())["x"].tzinfo + assert (aware.pymongo_test.test.find_one())["x"].replace(tzinfo=None) == ( + naive.pymongo_test.test.find_one() + )["x"] - @client_context.require_ipv6 - def test_ipv6(self): - if client_context.tls: + @pytest.mark.usefixtures("require_ipv6") + def test_ipv6(self, client_context_fixture, rs_or_single_client_noauth): + if client_context_fixture.tls: if not HAVE_IPADDRESS: - raise SkipTest("Need the ipaddress module to test with SSL") + pytest.skip("Need the ipaddress module to test with SSL") - if client_context.auth_enabled: + if client_context_fixture.auth_enabled: auth_str = f"{db_user}:{db_pwd}@" else: auth_str = "" - uri = "mongodb://%s[::1]:%d" % (auth_str, client_context.port) - if client_context.is_rs: - uri += "/?replicaSet=" + (client_context.replica_set_name or "") + uri = "mongodb://%s[::1]:%d" % (auth_str, client_context_fixture.port) + if client_context_fixture.is_rs: + uri += "/?replicaSet=" + (client_context_fixture.replica_set_name or "") - client = self.rs_or_single_client_noauth(uri) + client = rs_or_single_client_noauth(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() - self.assertTrue("pymongo_test" in dbs) - self.assertTrue("pymongo_test_bernie" in dbs) + assert "pymongo_test" in dbs + assert "pymongo_test_bernie" in dbs - def test_contextlib(self): - client = self.rs_or_single_client() + def test_contextlib(self, rs_or_single_client): + client = rs_or_single_client() client.pymongo_test.drop_collection("test") client.pymongo_test.test.insert_one({"foo": "bar"}) # The socket used for the previous commands has been returned to the # pool - self.assertEqual(1, len((get_pool(client)).conns)) + assert 1 == len((get_pool(client)).conns) # contextlib async support was added in Python 3.10 if _IS_SYNC or sys.version_info >= (3, 10): with contextlib.closing(client): - self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) - with self.assertRaises(InvalidOperation): + assert "bar" == (client.pymongo_test.test.find_one())["foo"] + with pytest.raises(InvalidOperation): client.pymongo_test.test.find_one() - client = self.rs_or_single_client() + client = rs_or_single_client() with client as client: - self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) - with self.assertRaises(InvalidOperation): + assert "bar" == (client.pymongo_test.test.find_one())["foo"] + with pytest.raises(InvalidOperation): client.pymongo_test.test.find_one() - @client_context.require_sync - def test_interrupt_signal(self): + @pytest.mark.usefixtures("require_sync") + def test_interrupt_signal(self, client_context_fixture): if sys.platform.startswith("java"): # We can't figure out how to raise an exception on a thread that's # blocked on a socket, whether that's the main thread or a worker, # without simply killing the whole thread in Jython. This suggests # PYTHON-294 can't actually occur in Jython. - raise SkipTest("Can't test interrupts in Jython") + pytest.skip("Can't test interrupts in Jython") if is_greenthread_patched(): - raise SkipTest("Can't reliably test interrupts with green threads") + pytest.skip("Can't reliably test interrupts with green threads") # Test fix for PYTHON-294 -- make sure MongoClient closes its # socket if it gets an interrupt while waiting to recv() from it. - db = self.client.pymongo_test + db = client_context_fixture.client.pymongo_test # A $where clause which takes 1.5 sec to execute where = delay(1.5) @@ -1397,48 +1413,48 @@ def sigalarm(num, frame): except KeyboardInterrupt: raised = True - # Can't use self.assertRaises() because it doesn't catch system - # exceptions - self.assertTrue(raised, "Didn't raise expected KeyboardInterrupt") + assert raised, "Didn't raise expected KeyboardInterrupt" # Raises AssertionError due to PYTHON-294 -- Mongo's response to # the previous find() is still waiting to be read on the socket, # so the request id's don't match. - self.assertEqual({"_id": 1}, next(db.foo.find())) # type: ignore[call-overload] + assert {"_id": 1} == next(db.foo.find()) # type: ignore[call-overload] finally: if old_signal_handler: signal.signal(signal.SIGALRM, old_signal_handler) - def test_operation_failure(self): + def test_operation_failure(self, single_client): # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = self.single_client() + client = single_client() client.pymongo_test.test.find_one() pool = get_pool(client) socket_count = len(pool.conns) - self.assertGreaterEqual(socket_count, 1) + assert socket_count >= 1 old_conn = next(iter(pool.conns)) client.pymongo_test.test.drop() client.pymongo_test.test.insert_one({"_id": "foo"}) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): client.pymongo_test.test.insert_one({"_id": "foo"}) - self.assertEqual(socket_count, len(pool.conns)) - new_con = next(iter(pool.conns)) - self.assertEqual(old_conn, new_con) + assert socket_count == len(pool.conns) + new_conn = next(iter(pool.conns)) + assert old_conn == new_conn - def test_lazy_connect_w0(self): + @pytest.mark.parametrize("drop_database_fixture", ["test_lazy_connect_w0"], indirect=True) + def test_lazy_connect_w0( + self, client_context_fixture, rs_or_single_client, drop_database_fixture + ): # Ensure that connect-on-demand works when the first operation is # an unacknowledged write. This exercises _writable_max_wire_version(). # Use a separate collection to avoid races where we're still # completing an operation on a collection while the next test begins. - client_context.client.drop_database("test_lazy_connect_w0") - self.addCleanup(client_context.client.drop_database, "test_lazy_connect_w0") + client_context_fixture.client.drop_database("test_lazy_connect_w0") - client = self.rs_or_single_client(connect=False, w=0) + client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.insert_one({}) def predicate(): @@ -1446,7 +1462,7 @@ def predicate(): wait_until(predicate, "find one document") - client = self.rs_or_single_client(connect=False, w=0) + client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) def predicate(): @@ -1454,7 +1470,7 @@ def predicate(): wait_until(predicate, "update one document") - client = self.rs_or_single_client(connect=False, w=0) + client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.delete_one({}) def predicate(): @@ -1462,11 +1478,11 @@ def predicate(): wait_until(predicate, "delete one document") - @client_context.require_no_mongos - def test_exhaust_network_error(self): + @pytest.mark.usefixtures("require_no_mongos") + def test_exhaust_network_error(self, rs_or_single_client): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = self.rs_or_single_client(maxPoolSize=1, retryReads=False) + client = rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1478,23 +1494,21 @@ def test_exhaust_network_error(self): conn = one(pool.conns) conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): next(cursor) - self.assertTrue(conn.closed) + assert conn.closed # The semaphore was decremented despite the error. - self.assertEqual(0, pool.requests) + assert 0 == pool.requests - @client_context.require_auth - def test_auth_network_error(self): + @pytest.mark.usefixtures("require_auth") + def test_auth_network_error(self, rs_or_single_client): # Make sure there's no semaphore leak if we get a network error # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. - c = connected( - self.rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) - ) + c = connected(rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False)) # Cause a network error on the actual socket. pool = get_pool(c) @@ -1503,25 +1517,25 @@ def test_auth_network_error(self): # Connection.authenticate logs, but gets a socket.error. Should be # reraised as AutoReconnect. - with self.assertRaises(AutoReconnect): + with pytest.raises(AutoReconnect): c.test.collection.find_one() # No semaphore leak, the pool is allowed to make a new socket. c.test.collection.find_one() - @client_context.require_no_replica_set - def test_connect_to_standalone_using_replica_set_name(self): - client = self.single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - with self.assertRaises(AutoReconnect): + @pytest.mark.usefixtures("require_no_replica_set") + def test_connect_to_standalone_using_replica_set_name(self, single_client): + client = single_client(replicaSet="anything", serverSelectionTimeoutMS=100) + with pytest.raises(AutoReconnect): client.test.test.find_one() - @client_context.require_replica_set - def test_stale_getmore(self): + @pytest.mark.usefixtures("require_replica_set") + def test_stale_getmore(self, rs_client): # A cursor is created, but its member goes down and is removed from # the topology before the getMore message is sent. Test that # MongoClient._run_operation_with_response handles the error. - with self.assertRaises(AutoReconnect): - client = self.rs_client(connect=False, serverSelectionTimeoutMS=100) + with pytest.raises(AutoReconnect): + client = rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1541,7 +1555,7 @@ def test_stale_getmore(self): address=("not-a-member", 27017), ) - def test_heartbeat_frequency_ms(self): + def test_heartbeat_frequency_ms(self, client_context_fixture, single_client): class HeartbeatStartedListener(ServerHeartbeatListener): def __init__(self): self.results = [] @@ -1566,116 +1580,117 @@ def init(self, *args): ServerHeartbeatStartedEvent.__init__ = init # type: ignore listener = HeartbeatStartedListener() uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % ( - client_context.host, - client_context.port, + client_context_fixture.host, + client_context_fixture.port, ) - self.single_client(uri, event_listeners=[listener]) + single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) # Default heartbeatFrequencyMS is 10 sec. Check the interval was # closer to 0.5 sec with heartbeatFrequencyMS configured. - self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) + pytest.approx(heartbeat_times[1] - heartbeat_times[0], 0.5, abs=2) finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" - with self.assertRaises(ConfigurationError) as context: + with pytest.raises(ConfigurationError) as context: MongoClient(uri) - self.assertIn("heartbeatFrequencyMS", str(context.exception)) + assert "heartbeatFrequencyMS" in str(context.value) - def test_compression(self): + def test_compression(self, client_context_fixture, simple_client, single_client): def compression_settings(client): pool_options = client.options.pool_options return pool_options._compression_settings - uri = "mongodb://localhost:27017/?compressors=zlib" - client = self.simple_client(uri, connect=False) + client = simple_client("mongodb://localhost:27017/?compressors=zlib", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + + client = simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, 4) - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == 4 + + client = simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 + + client = simple_client("mongodb://localhost:27017", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017/?compressors=foobar" - client = self.simple_client(uri, connect=False) + assert opts.compressors == [] + assert opts.zlib_compression_level == -1 + + client = simple_client("mongodb://localhost:27017/?compressors=foobar", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = self.simple_client(uri, connect=False) + assert opts.compressors == [] + assert opts.zlib_compression_level == -1 + + client = simple_client("mongodb://localhost:27017/?compressors=foobar,zlib", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 - # According to the connection string spec, unsupported values - # just raise a warning and are ignored. - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = self.simple_client(uri, connect=False) + client = simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) - uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 + + client = simple_client( + "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + assert opts.compressors == ["zlib"] + assert opts.zlib_compression_level == -1 if not _have_snappy(): - uri = "mongodb://localhost:27017/?compressors=snappy" - client = self.simple_client(uri, connect=False) + client = simple_client("mongodb://localhost:27017/?compressors=snappy", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + assert opts.compressors == [] else: - uri = "mongodb://localhost:27017/?compressors=snappy" - client = self.simple_client(uri, connect=False) + client = simple_client("mongodb://localhost:27017/?compressors=snappy", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy"]) - uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["snappy"] + client = simple_client( + "mongodb://localhost:27017/?compressors=snappy,zlib", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy", "zlib"]) + assert opts.compressors == ["snappy", "zlib"] if not _have_zstd(): - uri = "mongodb://localhost:27017/?compressors=zstd" - client = self.simple_client(uri, connect=False) + client = simple_client("mongodb://localhost:27017/?compressors=zstd", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + assert opts.compressors == [] else: - uri = "mongodb://localhost:27017/?compressors=zstd" - client = self.simple_client(uri, connect=False) + client = simple_client("mongodb://localhost:27017/?compressors=zstd", connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd"]) - uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = self.simple_client(uri, connect=False) + assert opts.compressors == ["zstd"] + client = simple_client( + "mongodb://localhost:27017/?compressors=zstd,zlib", connect=False + ) opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd", "zlib"]) + assert opts.compressors == ["zstd", "zlib"] - options = client_context.default_client_options + options = client_context_fixture.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = self.single_client(zlibcompressionlevel=level) - # No error - client.pymongo_test.test.find_one() + client = single_client(zlibcompressionlevel=level) + client.pymongo_test.test.find_one() # No error - @client_context.require_sync - def test_reset_during_update_pool(self): - client = self.rs_or_single_client(minPoolSize=10) + @pytest.mark.usefixtures("require_sync") + def test_reset_during_update_pool(self, rs_or_single_client): + client = rs_or_single_client(minPoolSize=10) client.admin.command("ping") pool = get_pool(client) generation = pool.gen.get_overall() @@ -1703,8 +1718,7 @@ def run(self): t = ResetPoolThread(pool) t.start() - # Ensure that update_pool completes without error even when the pool - # is reset concurrently. + # Ensure that update_pool completes without error even when the pool is reset concurrently. try: while True: for _ in range(10): @@ -1716,15 +1730,14 @@ def run(self): t.join() client.admin.command("ping") - def test_background_connections_do_not_hold_locks(self): + def test_background_connections_do_not_hold_locks(self, rs_or_single_client): min_pool_size = 10 - client = self.rs_or_single_client( + client = rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - # Create a single connection in the pool. - client.admin.command("ping") + client.admin.command("ping") # Create a single connection in the pool - # Cause new connections stall for a few seconds. + # Cause new connections to stall for a few seconds. pool = get_pool(client) original_connect = pool.connect @@ -1732,44 +1745,39 @@ def stall_connect(*args, **kwargs): time.sleep(2) return original_connect(*args, **kwargs) - pool.connect = stall_connect - # Un-patch Pool.connect to break the cyclic reference. - self.addCleanup(delattr, pool, "connect") - - # Wait for the background thread to start creating connections - wait_until(lambda: len(pool.conns) > 1, "start creating connections") + try: + pool.connect = stall_connect + + wait_until(lambda: len(pool.conns) > 1, "start creating connections") + # Assert that application operations do not block. + for _ in range(10): + start = time.monotonic() + client.admin.command("ping") + total = time.monotonic() - start + assert total < 2 + finally: + delattr(pool, "connect") - # Assert that application operations do not block. - for _ in range(10): - start = time.monotonic() - client.admin.command("ping") - total = time.monotonic() - start - # Each ping command should not take more than 2 seconds - self.assertLess(total, 2) - - @client_context.require_replica_set - def test_direct_connection(self): - # direct_connection=True should result in Single topology. - client = self.rs_or_single_client(directConnection=True) + @pytest.mark.usefixtures("require_replica_set") + def test_direct_connection(self, rs_or_single_client): + client = rs_or_single_client(directConnection=True) client.admin.command("ping") - self.assertEqual(len(client.nodes), 1) - self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) + assert len(client.nodes) == 1 + assert client._topology_settings.get_topology_type() == TOPOLOGY_TYPE.Single - # direct_connection=False should result in RS topology. - client = self.rs_or_single_client(directConnection=False) + client = rs_or_single_client(directConnection=False) client.admin.command("ping") - self.assertGreaterEqual(len(client.nodes), 1) - self.assertIn( - client._topology_settings.get_topology_type(), - [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], - ) + assert len(client.nodes) >= 1 + assert client._topology_settings.get_topology_type() in [ + TOPOLOGY_TYPE.ReplicaSetNoPrimary, + TOPOLOGY_TYPE.ReplicaSetWithPrimary, + ] - # directConnection=True, should error with multiple hosts as a list. - with self.assertRaises(ConfigurationError): + with pytest.raises(ConfigurationError): MongoClient(["host1", "host2"], directConnection=True) - @unittest.skipIf("PyPy" in sys.version, "PYTHON-2927 fails often on PyPy") - def test_continuous_network_errors(self): + @pytest.mark.skipif("PyPy" in sys.version, reason="PYTHON-2927 fails often on PyPy") + def test_continuous_network_errors(self, simple_client): def server_description_count(): i = 0 for obj in gc.get_objects(): @@ -1782,50 +1790,43 @@ def server_description_count(): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = self.simple_client( + client = simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - with self.assertRaises(ServerSelectionTimeoutError): + with pytest.raises(ServerSelectionTimeoutError): client.test.test.find_one() gc.collect() final_count = server_description_count() - # If a bug like PYTHON-2433 is reintroduced then too many - # ServerDescriptions will be kept alive and this test will fail: - # AssertionError: 19 != 46 within 15 delta (27 difference) - # On Python 3.11 we seem to get more of a delta. - self.assertAlmostEqual(initial_count, final_count, delta=20) - - @client_context.require_failCommand_fail_point - def test_network_error_message(self): - client = self.single_client(retryReads=False) + assert pytest.approx(initial_count, abs=20) == final_count + + @pytest.mark.usefixtures("require_failCommand_fail_point") + def test_network_error_message(self, single_client): + client = single_client(retryReads=False) client.admin.command("ping") # connect with self.fail_point( - {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} + client, + {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}, ): assert client.address is not None expected = "{}:{}: ".format(*(client.address)) - with self.assertRaisesRegex(AutoReconnect, expected): + with pytest.raises(AutoReconnect, match=expected): client.pymongo_test.test.find_one({}) - @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") - def test_process_periodic_tasks(self): - client = self.rs_or_single_client() + @pytest.mark.skipif("PyPy" in sys.version, reason="PYTHON-2938 could fail on PyPy") + def test_process_periodic_tasks(self, rs_or_single_client): + client = rs_or_single_client() coll = client.db.collection coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) cursor.next() c_id = cursor.cursor_id - self.assertIsNotNone(c_id) + assert c_id is not None client.close() - # Add cursor to kill cursors queue del cursor - wait_until( - lambda: client._kill_cursors_queue, - "waited for cursor to be added to queue", - ) + wait_until(lambda: client._kill_cursors_queue, "waited for cursor to be added to queue") client._process_periodic_tasks() # This must not raise or print any exceptions - with self.assertRaises(InvalidOperation): + with pytest.raises(InvalidOperation): coll.insert_many([{} for _ in range(5)]) def test_service_name_from_kwargs(self): @@ -1834,82 +1835,79 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) - self.assertEqual(client._topology_settings.srv_service_name, "customname") + assert client._topology_settings.srv_service_name == "customname" + client = MongoClient( - "mongodb+srv://user:password@test22.test.build.10gen.cc" - "/?srvServiceName=shouldbeoverriden", + "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=shouldbeoverriden", srvServiceName="customname", connect=False, ) - self.assertEqual(client._topology_settings.srv_service_name, "customname") + assert client._topology_settings.srv_service_name == "customname" + client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) - self.assertEqual(client._topology_settings.srv_service_name, "customname") - - def test_srv_max_hosts_kwarg(self): - client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") - self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) - self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = self.simple_client( + assert client._topology_settings.srv_service_name == "customname" + + def test_srv_max_hosts_kwarg(self, simple_client): + client = simple_client("mongodb+srv://test1.test.build.10gen.cc/") + assert len(client.topology_description.server_descriptions()) > 1 + + client = simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + assert len(client.topology_description.server_descriptions()) == 1 + + client = simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) - self.assertEqual(len(client.topology_description.server_descriptions()), 2) - - @unittest.skipIf( - client_context.load_balancer or client_context.serverless, - "loadBalanced clients do not run SDAM", - ) - @unittest.skipIf(sys.platform == "win32", "Windows does not support SIGSTOP") - @client_context.require_sync - def test_sigstop_sigcont(self): + assert len(client.topology_description.server_descriptions()) == 2 + + @pytest.mark.skipif(sys.platform == "win32", reason="Windows does not support SIGSTOP") + @pytest.mark.usefixtures("require_sdam") + @pytest.mark.usefixtures("require_sync") + def test_sigstop_sigcont(self, client_context_fixture): test_dir = os.path.dirname(os.path.realpath(__file__)) script = os.path.join(test_dir, "sigstop_sigcont.py") - p = subprocess.Popen( - [sys.executable, script, client_context.uri], + with subprocess.Popen( + [sys.executable, script, client_context_fixture.uri], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - ) - self.addCleanup(p.wait, timeout=1) - self.addCleanup(p.kill) - time.sleep(1) - # Stop the child, sleep for twice the streaming timeout - # (heartbeatFrequencyMS + connectTimeoutMS), and restart. - os.kill(p.pid, signal.SIGSTOP) - time.sleep(2) - os.kill(p.pid, signal.SIGCONT) - time.sleep(0.5) - # Tell the script to exit gracefully. - outs, _ = p.communicate(input=b"q\n", timeout=10) - self.assertTrue(outs) - log_output = outs.decode("utf-8") - self.assertIn("TEST STARTED", log_output) - self.assertIn("ServerHeartbeatStartedEvent", log_output) - self.assertIn("ServerHeartbeatSucceededEvent", log_output) - self.assertIn("TEST COMPLETED", log_output) - self.assertNotIn("ServerHeartbeatFailedEvent", log_output) - - def _test_handshake(self, env_vars, expected_env): + ) as p: + time.sleep(1) + os.kill(p.pid, signal.SIGSTOP) + time.sleep(2) + os.kill(p.pid, signal.SIGCONT) + time.sleep(0.5) + outs, _ = p.communicate(input=b"q\n", timeout=10) + assert outs + log_output = outs.decode("utf-8") + assert "TEST STARTED" in log_output + assert "ServerHeartbeatStartedEvent" in log_output + assert "ServerHeartbeatSucceededEvent" in log_output + assert "TEST COMPLETED" in log_output + assert "ServerHeartbeatFailedEvent" not in log_output + + def _test_handshake(self, env_vars, expected_env, rs_or_single_client): with patch.dict("os.environ", env_vars): metadata = copy.deepcopy(_METADATA) if has_c(): metadata["driver"]["name"] = "PyMongo|c" else: metadata["driver"]["name"] = "PyMongo" + if expected_env is not None: metadata["env"] = expected_env if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - client = self.rs_or_single_client(serverSelectionTimeoutMS=10000) + + client = rs_or_single_client(serverSelectionTimeoutMS=10000) client.admin.command("ping") options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + assert options.pool_options.metadata == metadata - def test_handshake_01_aws(self): + def test_handshake_01_aws(self, rs_or_single_client): self._test_handshake( { "AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", @@ -1917,12 +1915,18 @@ def test_handshake_01_aws(self): "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", }, {"name": "aws.lambda", "region": "us-east-2", "memory_mb": 1024}, + rs_or_single_client, ) - def test_handshake_02_azure(self): - self._test_handshake({"FUNCTIONS_WORKER_RUNTIME": "python"}, {"name": "azure.func"}) + def test_handshake_02_azure(self, rs_or_single_client): + self._test_handshake( + {"FUNCTIONS_WORKER_RUNTIME": "python"}, + {"name": "azure.func"}, + rs_or_single_client, + ) - def test_handshake_03_gcp(self): + def test_handshake_03_gcp(self, rs_or_single_client): + # Regular case with environment variables. self._test_handshake( { "K_SERVICE": "servicename", @@ -1931,7 +1935,9 @@ def test_handshake_03_gcp(self): "FUNCTION_REGION": "us-central1", }, {"name": "gcp.func", "region": "us-central1", "memory_mb": 1024, "timeout_sec": 60}, + rs_or_single_client, ) + # Extra case for FUNCTION_NAME. self._test_handshake( { @@ -1941,45 +1947,50 @@ def test_handshake_03_gcp(self): "FUNCTION_REGION": "us-central1", }, {"name": "gcp.func", "region": "us-central1", "memory_mb": 1024, "timeout_sec": 60}, + rs_or_single_client, ) - def test_handshake_04_vercel(self): + def test_handshake_04_vercel(self, rs_or_single_client): self._test_handshake( - {"VERCEL": "1", "VERCEL_REGION": "cdg1"}, {"name": "vercel", "region": "cdg1"} + {"VERCEL": "1", "VERCEL_REGION": "cdg1"}, + {"name": "vercel", "region": "cdg1"}, + rs_or_single_client, ) - def test_handshake_05_multiple(self): + def test_handshake_05_multiple(self, rs_or_single_client): self._test_handshake( {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "FUNCTIONS_WORKER_RUNTIME": "python"}, None, + rs_or_single_client, ) - # Extra cases for other combos. + self._test_handshake( {"FUNCTIONS_WORKER_RUNTIME": "python", "K_SERVICE": "servicename"}, None, + rs_or_single_client, ) - self._test_handshake({"K_SERVICE": "servicename", "VERCEL": "1"}, None) - def test_handshake_06_region_too_long(self): + self._test_handshake({"K_SERVICE": "servicename", "VERCEL": "1"}, None, rs_or_single_client) + + def test_handshake_06_region_too_long(self, rs_or_single_client): self._test_handshake( {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "AWS_REGION": "a" * 512}, {"name": "aws.lambda"}, + rs_or_single_client, ) - def test_handshake_07_memory_invalid_int(self): + def test_handshake_07_memory_invalid_int(self, rs_or_single_client): self._test_handshake( {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9", "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big"}, {"name": "aws.lambda"}, + rs_or_single_client, ) - def test_handshake_08_invalid_aws_ec2(self): + def test_handshake_08_invalid_aws_ec2(self, rs_or_single_client): # AWS_EXECUTION_ENV needs to start with "AWS_Lambda_". - self._test_handshake( - {"AWS_EXECUTION_ENV": "EC2"}, - None, - ) + self._test_handshake({"AWS_EXECUTION_ENV": "EC2"}, None, rs_or_single_client) - def test_handshake_09_container_with_provider(self): + def test_handshake_09_container_with_provider(self, rs_or_single_client): self._test_handshake( { ENV_VAR_K8S: "1", @@ -1993,102 +2004,96 @@ def test_handshake_09_container_with_provider(self): "region": "us-east-1", "memory_mb": 256, }, + rs_or_single_client, ) - def test_dict_hints(self): - self.db.t.find(hint={"x": 1}) + def test_dict_hints(self, client_context_fixture): + client_context_fixture.client.db.t.find(hint={"x": 1}) - def test_dict_hints_sort(self): - result = self.db.t.find() + def test_dict_hints_sort(self, client_context_fixture): + result = client_context_fixture.client.db.t.find() result.sort({"x": 1}) + client_context_fixture.client.db.t.find(sort={"x": 1}) - self.db.t.find(sort={"x": 1}) - - def test_dict_hints_create_index(self): - self.db.t.create_index({"x": pymongo.ASCENDING}) + def test_dict_hints_create_index(self, client_context_fixture): + client_context_fixture.client.db.t.create_index({"x": pymongo.ASCENDING}) - def test_legacy_java_uuid_roundtrip(self): + def test_legacy_java_uuid_roundtrip(self, client_context_fixture): data = BinaryData.java_data docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, JAVA_LEGACY)) - client_context.client.pymongo_test.drop_collection("java_uuid") - db = client_context.client.pymongo_test + client_context_fixture.client.pymongo_test.drop_collection("java_uuid") + db = client_context_fixture.client.pymongo_test coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) coll.insert_many(docs) - self.assertEqual(5, coll.count_documents({})) + assert coll.count_documents({}) == 5 for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + assert d["newguid"] == uuid.UUID(d["newguidstring"]) coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - client_context.client.pymongo_test.drop_collection("java_uuid") + assert d["newguid"] != d["newguidstring"] + client_context_fixture.client.pymongo_test.drop_collection("java_uuid") - def test_legacy_csharp_uuid_roundtrip(self): + def test_legacy_csharp_uuid_roundtrip(self, client_context_fixture): data = BinaryData.csharp_data docs = bson.decode_all(data, CodecOptions(SON[str, Any], False, CSHARP_LEGACY)) - client_context.client.pymongo_test.drop_collection("csharp_uuid") - db = client_context.client.pymongo_test + client_context_fixture.client.pymongo_test.drop_collection("csharp_uuid") + db = client_context_fixture.client.pymongo_test coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) coll.insert_many(docs) - self.assertEqual(5, coll.count_documents({})) + assert coll.count_documents({}) == 5 for d in coll.find(): - self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) + assert d["newguid"] == uuid.UUID(d["newguidstring"]) coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) for d in coll.find(): - self.assertNotEqual(d["newguid"], d["newguidstring"]) - client_context.client.pymongo_test.drop_collection("csharp_uuid") + assert d["newguid"] != d["newguidstring"] + client_context_fixture.client.pymongo_test.drop_collection("csharp_uuid") - def test_uri_to_uuid(self): + def test_uri_to_uuid(self, single_client): uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" - client = self.single_client(uri, connect=False) - self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) + client = single_client(uri, connect=False) + assert client.pymongo_test.test.codec_options.uuid_representation == CSHARP_LEGACY - def test_uuid_queries(self): - db = client_context.client.pymongo_test + def test_uuid_queries(self, client_context_fixture): + db = client_context_fixture.client.pymongo_test coll = db.test coll.drop() uu = uuid.uuid4() coll.insert_one({"uuid": Binary(uu.bytes, 3)}) - self.assertEqual(1, coll.count_documents({})) + assert coll.count_documents({}) == 1 - # Test regular UUID queries (using subtype 4). coll = db.get_collection( "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) ) - self.assertEqual(0, coll.count_documents({"uuid": uu})) + assert coll.count_documents({"uuid": uu}) == 0 coll.insert_one({"uuid": uu}) - self.assertEqual(2, coll.count_documents({})) - docs = coll.find({"uuid": uu}).to_list() - self.assertEqual(1, len(docs)) - self.assertEqual(uu, docs[0]["uuid"]) + assert coll.count_documents({}) == 2 + docs = coll.find({"uuid": uu}).to_list(length=1) + assert len(docs) == 1 + assert docs[0]["uuid"] == uu - # Test both. uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) predicate = {"uuid": {"$in": [uu, uu_legacy]}} - self.assertEqual(2, coll.count_documents(predicate)) - docs = coll.find(predicate).to_list() - self.assertEqual(2, len(docs)) + assert coll.count_documents(predicate) == 2 + docs = coll.find(predicate).to_list(length=2) + assert len(docs) == 2 coll.drop() -class TestExhaustCursor(IntegrationTest): - """Test that clients properly handle errors from exhaust cursors.""" - - def setUp(self): - super().setUp() - if client_context.is_mongos: - raise SkipTest("mongos doesn't support exhaust, SERVER-2627") - - def test_exhaust_query_server_error(self): +@pytest.mark.usefixtures("require_no_mongos") +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestExhaustCursor(PyMongoTestCasePyTest): + def test_exhaust_query_server_error(self, rs_or_single_client): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(self.rs_or_single_client(maxPoolSize=1)) + client = connected(rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = get_pool(client) @@ -2100,23 +2105,22 @@ def test_exhaust_query_server_error(self): SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST ) - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): cursor.next() - self.assertFalse(conn.closed) + assert not conn.closed # The socket was checked in and the semaphore was decremented. - self.assertIn(conn, pool.conns) - self.assertEqual(0, pool.requests) + assert conn in pool.conns + assert pool.requests == 0 - def test_exhaust_getmore_server_error(self): + def test_exhaust_getmore_server_error(self, rs_or_single_client): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = self.rs_or_single_client(maxPoolSize=1) + client = rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) - self.addCleanup(client_context.client.pymongo_test.test.drop) pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2138,19 +2142,19 @@ def receive_message(request_id): return message._OpReply.unpack(msg) conn.receive_message = receive_message - with self.assertRaises(OperationFailure): + with pytest.raises(OperationFailure): cursor.to_list() # Unpatch the instance. del conn.receive_message # The socket is returned to the pool and it still works. - self.assertEqual(200, collection.count_documents({})) - self.assertIn(conn, pool.conns) + assert 200 == collection.count_documents({}) + assert conn in pool.conns - def test_exhaust_query_network_error(self): + def test_exhaust_query_network_error(self, rs_or_single_client): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(self.rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = connected(rs_or_single_client(maxPoolSize=1, retryReads=False)) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2160,18 +2164,18 @@ def test_exhaust_query_network_error(self): conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): cursor.next() - self.assertTrue(conn.closed) + assert conn.closed # The socket was closed and the semaphore was decremented. - self.assertNotIn(conn, pool.conns) - self.assertEqual(0, pool.requests) + assert conn not in pool.conns + assert 0 == pool.requests - def test_exhaust_getmore_network_error(self): + def test_exhaust_getmore_network_error(self, rs_or_single_client): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = self.rs_or_single_client(maxPoolSize=1) + client = rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2188,39 +2192,39 @@ def test_exhaust_getmore_network_error(self): conn.conn.close() # A getmore fails. - with self.assertRaises(ConnectionFailure): + with pytest.raises(ConnectionFailure): cursor.to_list() - self.assertTrue(conn.closed) + assert conn.closed wait_until( lambda: len(client._kill_cursors_queue) == 0, "waited for all killCursor requests to complete", ) # The socket was closed and the semaphore was decremented. - self.assertNotIn(conn, pool.conns) - self.assertEqual(0, pool.requests) + assert conn not in pool.conns + assert 0 == pool.requests - @client_context.require_sync - def test_gevent_task(self): + @pytest.mark.usefixtures("require_sync") + def test_gevent_task(self, client_context_fixture): if not gevent_monkey_patched(): - raise SkipTest("Must be running monkey patched by gevent") + pytest.skip("Must be running monkey patched by gevent") from gevent import spawn def poller(): while True: - client_context.client.pymongo_test.test.insert_one({}) + client_context_fixture.client.pymongo_test.test.insert_one({}) task = spawn(poller) task.kill() - self.assertTrue(task.dead) + assert task.dead - @client_context.require_sync - def test_gevent_timeout(self): + @pytest.mark.usefixtures("require_sync") + def test_gevent_timeout(self, rs_or_single_client): if not gevent_monkey_patched(): - raise SkipTest("Must be running monkey patched by gevent") + pytest.skip("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = self.rs_or_single_client(maxPoolSize=1) + client = rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2241,19 +2245,19 @@ def timeout_task(): tt = spawn(timeout_task) tt.join(15) ct.join(15) - self.assertTrue(tt.dead) - self.assertTrue(ct.dead) - self.assertIsNone(tt.get()) - self.assertIsNone(ct.get()) + assert tt.dead + assert ct.dead + assert tt.get() is None + assert ct.get() is None - @client_context.require_sync - def test_gevent_timeout_when_creating_connection(self): + @pytest.mark.usefixtures("require_sync") + def test_gevent_timeout_when_creating_connection(self, rs_or_single_client): if not gevent_monkey_patched(): - raise SkipTest("Must be running monkey patched by gevent") + pytest.skip("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = self.rs_or_single_client() - self.addCleanup(client.close) + client = rs_or_single_client() + coll = client.pymongo_test.test pool = get_pool(client) @@ -2276,23 +2280,35 @@ def timeout_task(): tt.join(10) # Assert that we got our active_sockets count back - self.assertEqual(pool.active_sockets, 0) + assert pool.active_sockets == 0 # Assert the greenlet is dead - self.assertTrue(tt.dead) + assert tt.dead # Assert that the Timeout was raised all the way to the try - self.assertTrue(tt.get()) + assert tt.get() # Unpatch the instance. del pool.connect -class TestClientLazyConnect(IntegrationTest): +@pytest.mark.usefixtures("require_sync") +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestClientLazyConnect: """Test concurrent operations on a lazily-connecting MongoClient.""" - def _get_client(self): - return self.rs_or_single_client(connect=False) + @pytest.fixture + def _get_client(self, rs_or_single_client): + clients = [] + + def _make_client(): + client = rs_or_single_client(connect=False) + clients.append(client) + return client - @client_context.require_sync - def test_insert_one(self): + yield _make_client + for client in clients: + client.close() + + def test_insert_one(self, _get_client, client_context_fixture): def reset(collection): collection.drop() @@ -2300,12 +2316,11 @@ def insert_one(collection, _): collection.insert_one({}) def test(collection): - self.assertEqual(NTHREADS, collection.count_documents({})) + assert NTHREADS == collection.count_documents({}) - lazy_client_trial(reset, insert_one, test, self._get_client) + lazy_client_trial(reset, insert_one, test, _get_client, client_context_fixture) - @client_context.require_sync - def test_update_one(self): + def test_update_one(self, _get_client, client_context_fixture): def reset(collection): collection.drop() collection.insert_one({"i": 0}) @@ -2315,12 +2330,11 @@ def update_one(collection, _): collection.update_one({}, {"$inc": {"i": 1}}) def test(collection): - self.assertEqual(NTHREADS, collection.find_one()["i"]) + assert NTHREADS == collection.find_one()["i"] - lazy_client_trial(reset, update_one, test, self._get_client) + lazy_client_trial(reset, update_one, test, _get_client, client_context_fixture) - @client_context.require_sync - def test_delete_one(self): + def test_delete_one(self, _get_client, client_context_fixture): def reset(collection): collection.drop() collection.insert_many([{"i": i} for i in range(NTHREADS)]) @@ -2329,12 +2343,11 @@ def delete_one(collection, i): collection.delete_one({"i": i}) def test(collection): - self.assertEqual(0, collection.count_documents({})) + assert 0 == collection.count_documents({}) - lazy_client_trial(reset, delete_one, test, self._get_client) + lazy_client_trial(reset, delete_one, test, _get_client, client_context_fixture) - @client_context.require_sync - def test_find_one(self): + def test_find_one(self, _get_client, client_context_fixture): results: list = [] def reset(collection): @@ -2346,14 +2359,23 @@ def find_one(collection, _): results.append(collection.find_one()) def test(collection): - self.assertEqual(NTHREADS, len(results)) + assert NTHREADS == len(results) + + lazy_client_trial(reset, find_one, test, _get_client, client_context_fixture) - lazy_client_trial(reset, find_one, test, self._get_client) +@pytest.mark.usefixtures("require_no_load_balancer") +@pytest.mark.unit +class TestMongoClientFailover: + @pytest.fixture(scope="class", autouse=True) + def _client_knobs(self): + knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) + knobs.enable() + yield knobs + knobs.disable() -class TestMongoClientFailover(MockClientTest): - def test_discover_primary(self): - c = MockClient.get_mock_client( + def test_discover_primary(self, mock_client): + c = mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2361,11 +2383,10 @@ def test_discover_primary(self): replicaSet="rs", heartbeatFrequencyMS=500, ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") - self.assertEqual(c.address, ("a", 1)) + assert c.address == ("a", 1) # Fail over. c.kill_host("a:1") c.mock_primary = "b:2" @@ -2375,11 +2396,11 @@ def predicate(): wait_until(predicate, "wait for server address to be updated") # a:1 not longer in nodes. - self.assertLess(len(c.nodes), 3) + assert len(c.nodes) < 3 - def test_reconnect(self): + def test_reconnect(self, mock_client): # Verify the node list isn't forgotten during a network failure. - c = MockClient.get_mock_client( + c = mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2388,7 +2409,6 @@ def test_reconnect(self): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2400,7 +2420,7 @@ def test_reconnect(self): # MongoClient discovers it's alone. The first attempt raises either # ServerSelectionTimeoutError or AutoReconnect (from # AsyncMockPool.get_socket). - with self.assertRaises(AutoReconnect): + with pytest.raises(AutoReconnect): c.db.collection.find_one() # But it can reconnect. @@ -2408,14 +2428,14 @@ def test_reconnect(self): (c._get_topology()).select_servers( writable_server_selector, _Op.TEST, server_selection_timeout=10 ) - self.assertEqual(c.address, ("a", 1)) + assert c.address == ("a", 1) - def _test_network_error(self, operation_callback): + def _test_network_error(self, mock_client, operation_callback): # Verify only the disconnected server is reset by a network failure. # Disable background refresh. with client_knobs(heartbeat_frequency=999999): - c = MockClient( + c = mock_client( standalones=[], members=["a:1", "b:2"], mongoses=[], @@ -2426,8 +2446,6 @@ def _test_network_error(self, operation_callback): serverSelectionTimeoutMS=1000, ) - self.addCleanup(c.close) - # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION) c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1) @@ -2439,59 +2457,60 @@ def _test_network_error(self, operation_callback): # MongoClient is disconnected from the primary. This raises either # ServerSelectionTimeoutError or AutoReconnect (from # MockPool.get_socket). - with self.assertRaises(AutoReconnect): + with pytest.raises(AutoReconnect): operation_callback(c) # The primary's description is reset. server_a = (c._get_topology()).get_server_by_address(("a", 1)) sd_a = server_a.description - self.assertEqual(SERVER_TYPE.Unknown, sd_a.server_type) - self.assertEqual(0, sd_a.min_wire_version) - self.assertEqual(0, sd_a.max_wire_version) + assert SERVER_TYPE.Unknown == sd_a.server_type + assert 0 == sd_a.min_wire_version + assert 0 == sd_a.max_wire_version # ...but not the secondary's. server_b = (c._get_topology()).get_server_by_address(("b", 2)) sd_b = server_b.description - self.assertEqual(SERVER_TYPE.RSSecondary, sd_b.server_type) - self.assertEqual(2, sd_b.min_wire_version) - self.assertEqual(MIN_SUPPORTED_WIRE_VERSION + 1, sd_b.max_wire_version) + assert sd_b.server_type == SERVER_TYPE.RSSecondary + assert sd_b.min_wire_version == 2 + assert sd_b.max_wire_version == MIN_SUPPORTED_WIRE_VERSION + 1 - def test_network_error_on_query(self): + def test_network_error_on_query(self, mock_client): def callback(client): return client.db.collection.find_one() - self._test_network_error(callback) + self._test_network_error(mock_client, callback) - def test_network_error_on_insert(self): + def test_network_error_on_insert(self, mock_client): def callback(client): return client.db.collection.insert_one({}) - self._test_network_error(callback) + self._test_network_error(mock_client, callback) - def test_network_error_on_update(self): + def test_network_error_on_update(self, mock_client): def callback(client): return client.db.collection.update_one({}, {"$unset": "x"}) - self._test_network_error(callback) + self._test_network_error(mock_client, callback) - def test_network_error_on_replace(self): + def test_network_error_on_replace(self, mock_client): def callback(client): return client.db.collection.replace_one({}, {}) - self._test_network_error(callback) + self._test_network_error(mock_client, callback) - def test_network_error_on_delete(self): + def test_network_error_on_delete(self, mock_client): def callback(client): return client.db.collection.delete_many({}) - self._test_network_error(callback) + self._test_network_error(mock_client, callback) -class TestClientPool(MockClientTest): - @client_context.require_connection - def test_rs_client_does_not_maintain_pool_to_arbiters(self): +@pytest.mark.usefixtures("require_integration") +@pytest.mark.integration +class TestClientPool: + def test_rs_client_does_not_maintain_pool_to_arbiters(self, mock_client): listener = CMAPListener() - c = MockClient.get_mock_client( + c = mock_client( standalones=[], members=["a:1", "b:2", "c:3", "d:4"], mongoses=[], @@ -2502,27 +2521,21 @@ def test_rs_client_does_not_maintain_pool_to_arbiters(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") - self.assertEqual(c.address, ("a", 1)) - self.assertEqual(c.arbiters, {("c", 3)}) - # Assert that we create 2 and only 2 pooled connections. + assert c.address == ("a", 1) + assert c.arbiters == {("c", 3)} listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) - self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) - # Assert that we do not create connections to arbiters. + assert listener.event_count(monitoring.ConnectionCreatedEvent) == 2 arbiter = c._topology.get_server_by_address(("c", 3)) - self.assertFalse(arbiter.pool.conns) - # Assert that we do not create connections to unknown servers. + assert not arbiter.pool.conns arbiter = c._topology.get_server_by_address(("d", 4)) - self.assertFalse(arbiter.pool.conns) - # Arbiter pool is not marked ready. - self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 2) + assert not arbiter.pool.conns + assert listener.event_count(monitoring.PoolReadyEvent) == 2 - @client_context.require_connection - def test_direct_client_maintains_pool_to_arbiter(self): + def test_direct_client_maintains_pool_to_arbiter(self, mock_client): listener = CMAPListener() - c = MockClient.get_mock_client( + c = mock_client( standalones=[], members=["a:1", "b:2", "c:3"], mongoses=[], @@ -2532,18 +2545,11 @@ def test_direct_client_maintains_pool_to_arbiter(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") - self.assertEqual(c.address, ("c", 3)) - # Assert that we create 1 pooled connection. + assert c.address == ("c", 3) listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) - self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) + assert listener.event_count(monitoring.ConnectionCreatedEvent) == 1 arbiter = c._topology.get_server_by_address(("c", 3)) - self.assertEqual(len(arbiter.pool.conns), 1) - # Arbiter pool is marked ready. - self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1) - - -if __name__ == "__main__": - unittest.main() + assert len(arbiter.pool.conns) == 1 + assert listener.event_count(monitoring.PoolReadyEvent) == 1 diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 6771ea25f9..65f7994fb1 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -25,8 +25,7 @@ sys.path[0:0] = [""] -from test import client_context, unittest -from test.test_client import IntegrationTest +from test import IntegrationTest, client_context, unittest from bson import ( _BUILT_IN_TYPES, diff --git a/test/utils.py b/test/utils.py index 69154bc63b..bcbb1f9759 100644 --- a/test/utils.py +++ b/test/utils.py @@ -32,9 +32,10 @@ from collections import abc, defaultdict from functools import partial from test import client_context, db_pwd, db_user -from test.asynchronous import async_client_context from typing import Any, List +import pytest + from bson import json_util from bson.objectid import ObjectId from bson.son import SON @@ -810,7 +811,7 @@ def frequent_thread_switches(): sys.setswitchinterval(interval) -def lazy_client_trial(reset, target, test, get_client): +def lazy_client_trial(reset, target, test, get_client, client_context): """Test concurrent operations on a lazily-connecting client. `reset` takes a collection and resets it for the next trial. @@ -1022,3 +1023,10 @@ async def async_set_fail_point(client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) await client.admin.command(cmd) + + +def _default_pytest_mark(is_sync: bool): + if is_sync: + return pytest.mark.default + else: + return pytest.mark.asyncio(loop_scope="session") diff --git a/tools/synchro.py b/tools/synchro.py index dbcbbd1351..71b13c94f3 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -75,6 +75,8 @@ "AsyncPyMongoTestCase": "PyMongoTestCase", "AsyncMockClientTest": "MockClientTest", "async_client_context": "client_context", + "async_client": "client", + "async_mock_client": "mock_client", "async_setup": "setup", "asyncSetUp": "setUp", "asyncTearDown": "tearDown", @@ -121,6 +123,10 @@ "_async_cond_wait": "_cond_wait", } +removals: set[str] = { + 'loop_scope="session"', +} + docstring_replacements: dict[tuple[str, str], str] = { ("MongoClient", "connect"): """If ``True`` (the default), immediately begin connecting to MongoDB in the background. Otherwise connect @@ -187,6 +193,7 @@ def async_only_test(f: str) -> bool: "test_bulk.py", "test_change_stream.py", "test_client.py", + "test_client_pytest.py", "test_client_bulk_write.py", "test_client_context.py", "test_collation.py", @@ -229,7 +236,9 @@ def process_files( if file in docstring_translate_files: lines = translate_docstrings(lines) if file in sync_test_files: - translate_imports(lines) + lines = translate_imports(lines) + if file in sync_test_files: + lines = apply_removals(lines) f.seek(0) f.writelines(lines) f.truncate() @@ -331,6 +340,18 @@ def translate_docstrings(lines: list[str]) -> list[str]: return [line for line in lines if line != "DOCSTRING_REMOVED"] +def apply_removals(lines: list[str]) -> list[str]: + tokens_to_remove = [line for line in lines if any(t in line for t in removals)] + for token in removals: + for line in tokens_to_remove: + index = lines.index(line) + if token + ", " in line: + lines[index] = line.replace(token + ", ", "") + else: + lines[index] = line.replace(token, "") + return lines + + def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None: unasync_files( files,