From 6afd401705d077a61db2a9191af900087bcf5f1c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 7 Feb 2025 11:14:47 -0500 Subject: [PATCH 1/2] PYTHON-5092 - Convert test.test_pooling to async --- test/asynchronous/test_pooling.py | 599 ++++++++++++++++++++++++++++++ test/test_pooling.py | 103 +++-- test/utils.py | 5 + tools/synchro.py | 2 + 4 files changed, 655 insertions(+), 54 deletions(-) create mode 100644 test/asynchronous/test_pooling.py diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py new file mode 100644 index 0000000000..b08ba4d858 --- /dev/null +++ b/test/asynchronous/test_pooling.py @@ -0,0 +1,599 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test built in connection-pooling with threads.""" +from __future__ import annotations + +import asyncio +import gc +import random +import socket +import sys +import time +from test.asynchronous.helpers import ConcurrentRunner + +from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.son import SON +from pymongo import AsyncMongoClient, message, timeout +from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError +from pymongo.hello import HelloCompat +from pymongo.lock import _async_create_lock + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import async_get_pool, async_joinall, delay + +from pymongo.asynchronous.pool import Pool, PoolOptions +from pymongo.socket_checker import SocketChecker + +_IS_SYNC = False + + +N = 10 +DB = "pymongo-pooling-tests" + + +async def gc_collect_until_done(tasks, timeout=60): + start = time.time() + running = list(tasks) + while running: + assert (time.time() - start) < timeout, "Tasks timed out" + for t in running: + await t.join(0.1) + if not t.is_alive(): + running.remove(t) + gc.collect() + + +class MongoTask(ConcurrentRunner): + """A thread/Task that uses a AsyncMongoClient.""" + + def __init__(self, client): + super().__init__() + self.daemon = True # Don't hang whole test if task hangs. + self.client = client + self.db = self.client[DB] + self.passed = False + + async def run(self): + await self.run_mongo_thread() + self.passed = True + + async def run_mongo_thread(self): + raise NotImplementedError + + +class InsertOneAndFind(MongoTask): + async def run_mongo_thread(self): + for _ in range(N): + rand = random.randint(0, N) + _id = (await self.db.sf.insert_one({"x": rand})).inserted_id + assert rand == (await self.db.sf.find_one(_id))["x"] + + +class Unique(MongoTask): + async def run_mongo_thread(self): + for _ in range(N): + await self.db.unique.insert_one({}) # no error + + +class NonUnique(MongoTask): + async def run_mongo_thread(self): + for _ in range(N): + try: + await self.db.unique.insert_one({"_id": "jesse"}) + except DuplicateKeyError: + pass + else: + raise AssertionError("Should have raised DuplicateKeyError") + + +class SocketGetter(MongoTask): + """Utility for TestPooling. 
+ + Checks out a socket and holds it forever. Used in + test_no_wait_queue_timeout. + """ + + def __init__(self, client, pool): + super().__init__(client) + self.state = "init" + self.pool = pool + self.sock = None + + async def run_mongo_thread(self): + self.state = "get_socket" + + # Call 'pin_cursor' so we can hold the socket. + async with self.pool.checkout() as sock: + sock.pin_cursor() + self.sock = sock + + self.state = "connection" + + def __del__(self): + if self.sock: + self.sock.close_conn(None) + + +async def run_cases(client, cases): + tasks = [] + n_runs = 5 + + for case in cases: + for _i in range(n_runs): + t = case(client) + await t.start() + tasks.append(t) + + for t in tasks: + await t.join() + + for t in tasks: + assert t.passed, "%s.run() threw an exception" % repr(t) + + +class _TestPoolingBase(AsyncIntegrationTest): + """Base class for all connection-pool tests.""" + + @async_client_context.require_connection + async def asyncSetUp(self): + await super().asyncSetUp() + self.c = await self.async_rs_or_single_client() + db = self.c[DB] + await db.unique.drop() + await db.test.drop() + await db.unique.insert_one({"_id": "jesse"}) + await db.test.insert_many([{} for _ in range(10)]) + + async def asyncTearDown(self): + await self.c.close() + await super().asyncTearDown() + + async def create_pool(self, pair=None, *args, **kwargs): + if pair is None: + pair = (await async_client_context.host, await async_client_context.port) + # Start the pool with the correct ssl options. + pool_options = async_client_context.client._topology_settings.pool_options + kwargs["ssl_context"] = pool_options._ssl_context + kwargs["tls_allow_invalid_hostnames"] = pool_options.tls_allow_invalid_hostnames + kwargs["server_api"] = pool_options.server_api + pool = Pool(pair, PoolOptions(*args, **kwargs)) + await pool.ready() + return pool + + +class TestPooling(_TestPoolingBase): + async def test_max_pool_size_validation(self): + host, port = await async_client_context.host, await async_client_context.port + self.assertRaises(ValueError, AsyncMongoClient, host=host, port=port, maxPoolSize=-1) + + self.assertRaises(ValueError, AsyncMongoClient, host=host, port=port, maxPoolSize="foo") + + c = AsyncMongoClient(host=host, port=port, maxPoolSize=100, connect=False) + self.assertEqual(c.options.pool_options.max_pool_size, 100) + + async def test_no_disconnect(self): + await run_cases(self.c, [NonUnique, Unique, InsertOneAndFind]) + + async def test_pool_reuses_open_socket(self): + # Test Pool's _check_closed() method doesn't close a healthy socket. + cx_pool = await self.create_pool(max_pool_size=10) + cx_pool._check_interval_seconds = 0 # Always check. + async with cx_pool.checkout() as conn: + pass + + async with cx_pool.checkout() as new_connection: + self.assertEqual(conn, new_connection) + + self.assertEqual(1, len(cx_pool.conns)) + + async def test_get_socket_and_exception(self): + # get_socket() returns socket after a non-network error. + cx_pool = await self.create_pool(max_pool_size=1, wait_queue_timeout=1) + with self.assertRaises(ZeroDivisionError): + async with cx_pool.checkout() as conn: + 1 / 0 + + # Socket was returned, not closed. + async with cx_pool.checkout() as new_connection: + self.assertEqual(conn, new_connection) + + self.assertEqual(1, len(cx_pool.conns)) + + async def test_pool_removes_closed_socket(self): + # Test that Pool removes explicitly closed socket. 
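+        # A connection closed with close_conn() must not be checked back in;
+        # a later checkout would otherwise hand out a dead socket.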
+ cx_pool = await self.create_pool() + + async with cx_pool.checkout() as conn: + # Use Connection's API to close the socket. + conn.close_conn(None) + + self.assertEqual(0, len(cx_pool.conns)) + + async def test_pool_removes_dead_socket(self): + # Test that Pool removes dead socket and the socket doesn't return + # itself PYTHON-344 + cx_pool = await self.create_pool(max_pool_size=1, wait_queue_timeout=1) + cx_pool._check_interval_seconds = 0 # Always check. + + async with cx_pool.checkout() as conn: + # Simulate a closed socket without telling the Connection it's + # closed. + conn.conn.close() + self.assertTrue(conn.conn_closed()) + + async with cx_pool.checkout() as new_connection: + self.assertEqual(0, len(cx_pool.conns)) + self.assertNotEqual(conn, new_connection) + + self.assertEqual(1, len(cx_pool.conns)) + + # Semaphore was released. + async with cx_pool.checkout(): + pass + + async def test_socket_closed(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((await async_client_context.host, await async_client_context.port)) + socket_checker = SocketChecker() + self.assertFalse(socket_checker.socket_closed(s)) + s.close() + self.assertTrue(socket_checker.socket_closed(s)) + + async def test_socket_checker(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((await async_client_context.host, await async_client_context.port)) + socket_checker = SocketChecker() + # Socket has nothing to read. + self.assertFalse(socket_checker.select(s, read=True)) + self.assertFalse(socket_checker.select(s, read=True, timeout=0)) + self.assertFalse(socket_checker.select(s, read=True, timeout=0.05)) + # Socket is writable. + self.assertTrue(socket_checker.select(s, write=True, timeout=None)) + self.assertTrue(socket_checker.select(s, write=True)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0.05)) + # Make the socket readable + _, msg, _ = message._query( + 0, "admin.$cmd", 0, -1, SON([("ping", 1)]), None, DEFAULT_CODEC_OPTIONS + ) + s.sendall(msg) + # Block until the socket is readable. + self.assertTrue(socket_checker.select(s, read=True, timeout=None)) + self.assertTrue(socket_checker.select(s, read=True)) + self.assertTrue(socket_checker.select(s, read=True, timeout=0)) + self.assertTrue(socket_checker.select(s, read=True, timeout=0.05)) + # Socket is still writable. + self.assertTrue(socket_checker.select(s, write=True, timeout=None)) + self.assertTrue(socket_checker.select(s, write=True)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0.05)) + s.close() + self.assertTrue(socket_checker.socket_closed(s)) + + async def test_return_socket_after_reset(self): + pool = await self.create_pool() + async with pool.checkout() as sock: + self.assertEqual(pool.active_sockets, 1) + self.assertEqual(pool.operation_count, 1) + await pool.reset() + + self.assertTrue(sock.closed) + self.assertEqual(0, len(pool.conns)) + self.assertEqual(pool.active_sockets, 0) + self.assertEqual(pool.operation_count, 0) + + async def test_pool_check(self): + # Test that Pool recovers from two connection failures in a row. + # This exercises code at the end of Pool._check(). + cx_pool = await self.create_pool(max_pool_size=1, connect_timeout=1, wait_queue_timeout=1) + cx_pool._check_interval_seconds = 0 # Always check. 
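+        # A zero interval makes every checkout re-verify its socket, so the
+        # dead socket below is noticed immediately and forces a reconnect.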
+ self.addAsyncCleanup(cx_pool.close) + + async with cx_pool.checkout() as conn: + # Simulate a closed socket without telling the Connection it's + # closed. + conn.conn.close() + + # Swap pool's address with a bad one. + address, cx_pool.address = cx_pool.address, ("foo.com", 1234) + with self.assertRaises(AutoReconnect): + async with cx_pool.checkout(): + pass + + # Back to normal, semaphore was correctly released. + cx_pool.address = address + async with cx_pool.checkout(): + pass + + async def test_wait_queue_timeout(self): + wait_queue_timeout = 2 # Seconds + pool = await self.create_pool(max_pool_size=1, wait_queue_timeout=wait_queue_timeout) + self.addAsyncCleanup(pool.close) + + async with pool.checkout(): + start = time.time() + with self.assertRaises(ConnectionFailure): + async with pool.checkout(): + pass + + duration = time.time() - start + self.assertTrue( + abs(wait_queue_timeout - duration) < 1, + f"Waited {duration:.2f} seconds for a socket, expected {wait_queue_timeout:f}", + ) + + async def test_no_wait_queue_timeout(self): + # Verify get_socket() with no wait_queue_timeout blocks forever. + pool = await self.create_pool(max_pool_size=1) + self.addAsyncCleanup(pool.close) + + # Reach max_size. + async with pool.checkout() as s1: + t = SocketGetter(self.c, pool) + await t.start() + while t.state != "get_socket": + await asyncio.sleep(0.1) + + await asyncio.sleep(1) + self.assertEqual(t.state, "get_socket") + + while t.state != "connection": + await asyncio.sleep(0.1) + + self.assertEqual(t.state, "connection") + self.assertEqual(t.sock, s1) + + async def test_checkout_more_than_max_pool_size(self): + pool = await self.create_pool(max_pool_size=2) + + socks = [] + for _ in range(2): + # Call 'pin_cursor' so we can hold the socket. + async with pool.checkout() as sock: + sock.pin_cursor() + socks.append(sock) + + tasks = [] + for _ in range(30): + t = SocketGetter(self.c, pool) + await t.start() + tasks.append(t) + await asyncio.sleep(1) + for t in tasks: + self.assertEqual(t.state, "get_socket") + + for socket_info in socks: + socket_info.close_conn(None) + + async def test_maxConnecting(self): + client = await self.async_rs_or_single_client() + await self.client.test.test.insert_one({}) + self.addAsyncCleanup(self.client.test.test.delete_many, {}) + pool = await async_get_pool(client) + docs = [] + + # Run 50 short running operations + async def find_one(): + docs.append(await client.test.test.find_one({})) + + tasks = [ConcurrentRunner(target=find_one) for _ in range(50)] + for task in tasks: + await task.start() + for task in tasks: + await task.join(10) + + self.assertEqual(len(docs), 50) + self.assertLessEqual(len(pool.conns), 50) + # TLS and auth make connection establishment more expensive than + # the query which leads to more threads hitting maxConnecting. + # The end result is fewer total connections and better latency. 
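+        # With an expensive handshake (TLS + auth), tasks queue on
+        # maxConnecting and reuse the first few established connections
+        # rather than each opening its own.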
+ if async_client_context.tls and async_client_context.auth_enabled: + self.assertLessEqual(len(pool.conns), 30) + else: + self.assertLessEqual(len(pool.conns), 50) + # MongoDB 4.4.1 with auth + ssl: + # maxConnecting = 2: 6 connections in ~0.231+ seconds + # maxConnecting = unbounded: 50 connections in ~0.642+ seconds + # + # MongoDB 4.4.1 with no-auth no-ssl Python 3.8: + # maxConnecting = 2: 15-22 connections in ~0.108+ seconds + # maxConnecting = unbounded: 30+ connections in ~0.140+ seconds + print(len(pool.conns)) + + @async_client_context.require_failCommand_appName + async def test_csot_timeout_message(self): + client = await self.async_rs_or_single_client(appName="connectionTimeoutApp") + # Mock an operation failing due to pymongo.timeout(). + mock_connection_timeout = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "blockConnection": True, + "blockTimeMS": 1000, + "failCommands": ["find"], + "appName": "connectionTimeoutApp", + }, + } + + await client.db.t.insert_one({"x": 1}) + + async with self.fail_point(mock_connection_timeout): + with self.assertRaises(Exception) as error: + with timeout(0.5): + await client.db.t.find_one({"$where": delay(2)}) + + self.assertTrue("(configured timeouts: timeoutMS: 500.0ms" in str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_socket_timeout_message(self): + client = await self.async_rs_or_single_client( + socketTimeoutMS=500, appName="connectionTimeoutApp" + ) + # Mock an operation failing due to socketTimeoutMS. + mock_connection_timeout = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "blockConnection": True, + "blockTimeMS": 1000, + "failCommands": ["find"], + "appName": "connectionTimeoutApp", + }, + } + + await client.db.t.insert_one({"x": 1}) + + async with self.fail_point(mock_connection_timeout): + with self.assertRaises(Exception) as error: + await client.db.t.find_one({"$where": delay(2)}) + + self.assertTrue( + "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 20000.0ms)" + in str(error.exception) + ) + + @async_client_context.require_failCommand_appName + async def test_connection_timeout_message(self): + # Mock a connection creation failing due to timeout. + mock_connection_timeout = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "blockConnection": True, + "blockTimeMS": 1000, + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "appName": "connectionTimeoutApp", + }, + } + + client = await self.async_rs_or_single_client( + connectTimeoutMS=500, + socketTimeoutMS=500, + appName="connectionTimeoutApp", + heartbeatFrequencyMS=1000000, + ) + await client.admin.command("ping") + pool = await async_get_pool(client) + await pool.reset_without_pause() + async with self.fail_point(mock_connection_timeout): + with self.assertRaises(Exception) as error: + await client.admin.command("ping") + + self.assertTrue( + "(configured timeouts: socketTimeoutMS: 500.0ms, connectTimeoutMS: 500.0ms)" + in str(error.exception) + ) + + +class TestPoolMaxSize(_TestPoolingBase): + async def test_max_pool_size(self): + max_pool_size = 4 + c = await self.async_rs_or_single_client(maxPoolSize=max_pool_size) + collection = c[DB].test + + # Need one document. + await collection.drop() + await collection.insert_one({}) + + # ntasks had better be much larger than max_pool_size to ensure that + # max_pool_size connections are actually required at some point in this + # test's execution. 
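+        # Each find_one below holds its connection for ~0.1s via the $where
+        # delay, forcing the tasks to contend for the pool.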
+ cx_pool = await async_get_pool(c) + ntasks = 10 + tasks = [] + lock = _async_create_lock() + self.n_passed = 0 + + async def f(): + for _ in range(5): + await collection.find_one({"$where": delay(0.1)}) + assert len(cx_pool.conns) <= max_pool_size + + async with lock: + self.n_passed += 1 + + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) + await t.start() + + await async_joinall(tasks) + self.assertEqual(ntasks, self.n_passed) + self.assertTrue(len(cx_pool.conns) > 1) + self.assertEqual(0, cx_pool.requests) + + async def test_max_pool_size_none(self): + c = await self.async_rs_or_single_client(maxPoolSize=None) + collection = c[DB].test + + # Need one document. + await collection.drop() + await collection.insert_one({}) + + cx_pool = await async_get_pool(c) + ntasks = 10 + tasks = [] + lock = _async_create_lock() + self.n_passed = 0 + + async def f(): + for _ in range(5): + await collection.find_one({"$where": delay(0.1)}) + + async with lock: + self.n_passed += 1 + + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) + await t.start() + + await async_joinall(tasks) + self.assertEqual(ntasks, self.n_passed) + self.assertTrue(len(cx_pool.conns) > 1) + self.assertEqual(cx_pool.max_pool_size, float("inf")) + + async def test_max_pool_size_zero(self): + c = await self.async_rs_or_single_client(maxPoolSize=0) + pool = await async_get_pool(c) + self.assertEqual(pool.max_pool_size, float("inf")) + + async def test_max_pool_size_with_connection_failure(self): + # The pool acquires its semaphore before attempting to connect; ensure + # it releases the semaphore on connection failure. + test_pool = Pool( + ("somedomainthatdoesntexist.org", 27017), + PoolOptions(max_pool_size=1, connect_timeout=1, socket_timeout=1, wait_queue_timeout=1), + ) + await test_pool.ready() + + # First call to get_socket fails; if pool doesn't release its semaphore + # then the second call raises "ConnectionFailure: Timed out waiting for + # socket from pool" instead of AutoReconnect. + for _i in range(2): + with self.assertRaises(AutoReconnect) as context: + async with test_pool.checkout(): + pass + + # Testing for AutoReconnect instead of ConnectionFailure, above, + # is sufficient right *now* to catch a semaphore leak. But that + # seems error-prone, so check the message too. 
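+            # Relying on the exception type alone is brittle; the message
+            # distinguishes a connect failure from a wait-queue timeout.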
+ self.assertNotIn("waiting for socket from pool", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_pooling.py b/test/test_pooling.py index 3b867965bd..41e7fc3fcb 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -15,18 +15,20 @@ """Test built in connection-pooling with threads.""" from __future__ import annotations +import asyncio import gc import random import socket import sys -import threading import time +from test.helpers import ConcurrentRunner from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON from pymongo import MongoClient, message, timeout from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError from pymongo.hello import HelloCompat +from pymongo.lock import _create_lock sys.path[0:0] = [""] @@ -36,21 +38,18 @@ from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions - -@client_context.require_connection -def setUpModule(): - pass +_IS_SYNC = True N = 10 DB = "pymongo-pooling-tests" -def gc_collect_until_done(threads, timeout=60): +def gc_collect_until_done(tasks, timeout=60): start = time.time() - running = list(threads) + running = list(tasks) while running: - assert (time.time() - start) < timeout, "Threads timed out" + assert (time.time() - start) < timeout, "Tasks timed out" for t in running: t.join(0.1) if not t.is_alive(): @@ -58,12 +57,12 @@ def gc_collect_until_done(threads, timeout=60): gc.collect() -class MongoThread(threading.Thread): - """A thread that uses a MongoClient.""" +class MongoTask(ConcurrentRunner): + """A thread/Task that uses a MongoClient.""" def __init__(self, client): super().__init__() - self.daemon = True # Don't hang whole test if thread hangs. + self.daemon = True # Don't hang whole test if task hangs. self.client = client self.db = self.client[DB] self.passed = False @@ -76,21 +75,21 @@ def run_mongo_thread(self): raise NotImplementedError -class InsertOneAndFind(MongoThread): +class InsertOneAndFind(MongoTask): def run_mongo_thread(self): for _ in range(N): rand = random.randint(0, N) - _id = self.db.sf.insert_one({"x": rand}).inserted_id - assert rand == self.db.sf.find_one(_id)["x"] + _id = (self.db.sf.insert_one({"x": rand})).inserted_id + assert rand == (self.db.sf.find_one(_id))["x"] -class Unique(MongoThread): +class Unique(MongoTask): def run_mongo_thread(self): for _ in range(N): self.db.unique.insert_one({}) # no error -class NonUnique(MongoThread): +class NonUnique(MongoTask): def run_mongo_thread(self): for _ in range(N): try: @@ -101,7 +100,7 @@ def run_mongo_thread(self): raise AssertionError("Should have raised DuplicateKeyError") -class SocketGetter(MongoThread): +class SocketGetter(MongoTask): """Utility for TestPooling. Checks out a socket and holds it forever. 
Used in @@ -130,25 +129,26 @@ def __del__(self): def run_cases(client, cases): - threads = [] + tasks = [] n_runs = 5 for case in cases: for _i in range(n_runs): t = case(client) t.start() - threads.append(t) + tasks.append(t) - for t in threads: + for t in tasks: t.join() - for t in threads: + for t in tasks: assert t.passed, "%s.run() threw an exception" % repr(t) class _TestPoolingBase(IntegrationTest): """Base class for all connection-pool tests.""" + @client_context.require_connection def setUp(self): super().setUp() self.c = self.rs_or_single_client() @@ -162,7 +162,9 @@ def tearDown(self): self.c.close() super().tearDown() - def create_pool(self, pair=(client_context.host, client_context.port), *args, **kwargs): + def create_pool(self, pair=None, *args, **kwargs): + if pair is None: + pair = (client_context.host, client_context.port) # Start the pool with the correct ssl options. pool_options = client_context.client._topology_settings.pool_options kwargs["ssl_context"] = pool_options._ssl_context @@ -365,13 +367,13 @@ def test_checkout_more_than_max_pool_size(self): sock.pin_cursor() socks.append(sock) - threads = [] + tasks = [] for _ in range(30): t = SocketGetter(self.c, pool) t.start() - threads.append(t) + tasks.append(t) time.sleep(1) - for t in threads: + for t in tasks: self.assertEqual(t.state, "get_socket") for socket_info in socks: @@ -379,7 +381,6 @@ def test_checkout_more_than_max_pool_size(self): def test_maxConnecting(self): client = self.rs_or_single_client() - self.addCleanup(client.close) self.client.test.test.insert_one({}) self.addCleanup(self.client.test.test.delete_many, {}) pool = get_pool(client) @@ -389,11 +390,11 @@ def test_maxConnecting(self): def find_one(): docs.append(client.test.test.find_one({})) - threads = [threading.Thread(target=find_one) for _ in range(50)] - for thread in threads: - thread.start() - for thread in threads: - thread.join(10) + tasks = [ConcurrentRunner(target=find_one) for _ in range(50)] + for task in tasks: + task.start() + for task in tasks: + task.join(10) self.assertEqual(len(docs), 50) self.assertLessEqual(len(pool.conns), 50) @@ -416,7 +417,6 @@ def find_one(): @client_context.require_failCommand_appName def test_csot_timeout_message(self): client = self.rs_or_single_client(appName="connectionTimeoutApp") - self.addCleanup(client.close) # Mock an operation failing due to pymongo.timeout(). mock_connection_timeout = { "configureFailPoint": "failCommand", @@ -441,7 +441,6 @@ def test_csot_timeout_message(self): @client_context.require_failCommand_appName def test_socket_timeout_message(self): client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") - self.addCleanup(client.close) # Mock an operation failing due to socketTimeoutMS. mock_connection_timeout = { "configureFailPoint": "failCommand", @@ -485,7 +484,6 @@ def test_connection_timeout_message(self): appName="connectionTimeoutApp", heartbeatFrequencyMS=1000000, ) - self.addCleanup(client.close) client.admin.command("ping") pool = get_pool(client) pool.reset_without_pause() @@ -503,20 +501,19 @@ class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): max_pool_size = 4 c = self.rs_or_single_client(maxPoolSize=max_pool_size) - self.addCleanup(c.close) collection = c[DB].test # Need one document. 
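         # A single document suffices: every find_one in f() holds a
         # connection for the duration of the $where delay.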
collection.drop() collection.insert_one({}) - # nthreads had better be much larger than max_pool_size to ensure that + # ntasks had better be much larger than max_pool_size to ensure that # max_pool_size connections are actually required at some point in this # test's execution. cx_pool = get_pool(c) - nthreads = 10 - threads = [] - lock = threading.Lock() + ntasks = 10 + tasks = [] + lock = _create_lock() self.n_passed = 0 def f(): @@ -527,19 +524,18 @@ def f(): with lock: self.n_passed += 1 - for _i in range(nthreads): - t = threading.Thread(target=f) - threads.append(t) + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) t.start() - joinall(threads) - self.assertEqual(nthreads, self.n_passed) + joinall(tasks) + self.assertEqual(ntasks, self.n_passed) self.assertTrue(len(cx_pool.conns) > 1) self.assertEqual(0, cx_pool.requests) def test_max_pool_size_none(self): c = self.rs_or_single_client(maxPoolSize=None) - self.addCleanup(c.close) collection = c[DB].test # Need one document. @@ -547,9 +543,9 @@ def test_max_pool_size_none(self): collection.insert_one({}) cx_pool = get_pool(c) - nthreads = 10 - threads = [] - lock = threading.Lock() + ntasks = 10 + tasks = [] + lock = _create_lock() self.n_passed = 0 def f(): @@ -559,19 +555,18 @@ def f(): with lock: self.n_passed += 1 - for _i in range(nthreads): - t = threading.Thread(target=f) - threads.append(t) + for _i in range(ntasks): + t = ConcurrentRunner(target=f) + tasks.append(t) t.start() - joinall(threads) - self.assertEqual(nthreads, self.n_passed) + joinall(tasks) + self.assertEqual(ntasks, self.n_passed) self.assertTrue(len(cx_pool.conns) > 1) self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): c = self.rs_or_single_client(maxPoolSize=0) - self.addCleanup(c.close) pool = get_pool(c) self.assertEqual(pool.max_pool_size, float("inf")) diff --git a/test/utils.py b/test/utils.py index 5c1e0bfb7c..40eec01cb4 100644 --- a/test/utils.py +++ b/test/utils.py @@ -666,6 +666,11 @@ def joinall(threads): assert not t.is_alive(), "Thread %s hung" % t +async def async_joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) + + def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true. 
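A note on intended usage: `async_joinall` awaits the wrapped asyncio tasks that `ConcurrentRunner` stores on its `task` attribute. Unlike the synchronous `joinall` above, it only waits for completion; it does not assert that every task finished. A minimal usage sketch (the worker coroutine and task count are hypothetical):

    async def worker():
        ...  # any awaitable test body

    tasks = [ConcurrentRunner(target=worker) for _ in range(10)]
    for t in tasks:
        await t.start()
    # Waits up to five minutes for all wrapped tasks to complete.
    await async_joinall(tasks)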
diff --git a/tools/synchro.py b/tools/synchro.py index 7e7aeec3a4..2e9d2c8289 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -123,6 +123,7 @@ "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", "create_async_event": "create_event", + "async_joinall": "joinall", } docstring_replacements: dict[tuple[str, str], str] = { @@ -223,6 +224,7 @@ def async_only_test(f: str) -> bool: "test_monitoring.py", "test_mongos_load_balancing.py", "test_on_demand_csfle.py", + "test_pooling.py", "test_raw_bson.py", "test_read_concern.py", "test_read_preferences.py", From 0b9e0961388de7b4e64fec7860d10e9009aece5e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 13 Feb 2025 09:39:14 -0800 Subject: [PATCH 2/2] Address review --- test/asynchronous/test_pooling.py | 6 +----- test/test_pooling.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index b08ba4d858..09b8fb0853 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -21,7 +21,6 @@ import socket import sys import time -from test.asynchronous.helpers import ConcurrentRunner from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -33,6 +32,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import ConcurrentRunner from test.utils import async_get_pool, async_joinall, delay from pymongo.asynchronous.pool import Pool, PoolOptions @@ -158,10 +158,6 @@ async def asyncSetUp(self): await db.unique.insert_one({"_id": "jesse"}) await db.test.insert_many([{} for _ in range(10)]) - async def asyncTearDown(self): - await self.c.close() - await super().asyncTearDown() - async def create_pool(self, pair=None, *args, **kwargs): if pair is None: pair = (await async_client_context.host, await async_client_context.port) diff --git a/test/test_pooling.py b/test/test_pooling.py index 41e7fc3fcb..5d23b85f23 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -21,7 +21,6 @@ import socket import sys import time -from test.helpers import ConcurrentRunner from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -33,6 +32,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest +from test.helpers import ConcurrentRunner from test.utils import delay, get_pool, joinall from pymongo.socket_checker import SocketChecker @@ -158,10 +158,6 @@ def setUp(self): db.unique.insert_one({"_id": "jesse"}) db.test.insert_many([{} for _ in range(10)]) - def tearDown(self): - self.c.close() - super().tearDown() - def create_pool(self, pair=None, *args, **kwargs): if pair is None: pair = (client_context.host, client_context.port)
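For readers unfamiliar with the synchro tooling: entries in the replacement table above, such as "async_joinall": "joinall", are identifier substitutions applied when the synchronous test suite is generated from the asynchronous one. A rough sketch of the substitution idea, assuming a plain dict of replacements and a hypothetical synchronize() helper (the real tools/synchro.py builds on unasync, which also rewrites async/await syntax, and applies further rules):

    import re

    replacements = {"AsyncMongoClient": "MongoClient", "async_joinall": "joinall"}

    def synchronize(source: str) -> str:
        for async_name, sync_name in replacements.items():
            # Word-boundary match so one name cannot clobber a longer
            # identifier that merely contains it.
            source = re.sub(rf"\b{re.escape(async_name)}\b", sync_name, source)
        return source

    print(synchronize("await async_joinall(tasks)"))  # -> "await joinall(tasks)"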