diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 5c7ea9ca..97ce885e 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -7,20 +7,16 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable import functools import inspect import logging import time +import warnings +from collections.abc import Awaitable, Callable from types import TracebackType from typing import Any, Optional, Type -import warnings - -from . import compat -from . import connection -from . import exceptions -from . import protocol +from . import compat, connection, exceptions, protocol logger = logging.getLogger(__name__) @@ -338,27 +334,46 @@ class Pool: """ __slots__ = ( - '_queue', '_loop', '_minsize', '_maxsize', - '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs', - '_holders', '_initialized', '_initializing', '_closing', - '_closed', '_connection_class', '_record_class', '_generation', - '_setup', '_max_queries', '_max_inactive_connection_lifetime' + "_queue", + "_loop", + "_minsize", + "_maxsize", + "_init", + "_connect", + "_reset", + "_connect_args", + "_connect_kwargs", + "_holders", + "_initialized", + "_initializing", + "_closing", + "_closed", + "_connection_class", + "_record_class", + "_generation", + "_setup", + "_max_queries", + "_max_inactive_connection_lifetime", + "_pool_timeout", ) - def __init__(self, *connect_args, - min_size, - max_size, - max_queries, - max_inactive_connection_lifetime, - connect=None, - setup=None, - init=None, - reset=None, - loop, - connection_class, - record_class, - **connect_kwargs): - + def __init__( + self, + *connect_args, + min_size, + max_size, + max_queries, + max_inactive_connection_lifetime, + pool_timeout=None, + connect=None, + setup=None, + init=None, + reset=None, + loop, + connection_class, + record_class, + **connect_kwargs, + ): if len(connect_args) > 1: warnings.warn( "Passing multiple positional arguments to asyncpg.Pool " @@ -389,6 +404,11 @@ def __init__(self, *connect_args, 'max_inactive_connection_lifetime is expected to be greater ' 'or equal to zero') + if pool_timeout is not None and pool_timeout <= 0: + raise ValueError( + "pool_timeout is expected to be greater than zero or None" + ) + if not issubclass(connection_class, connection.Connection): raise TypeError( 'connection_class is expected to be a subclass of ' @@ -423,8 +443,10 @@ def __init__(self, *connect_args, self._reset = reset self._max_queries = max_queries - self._max_inactive_connection_lifetime = \ + self._max_inactive_connection_lifetime = ( max_inactive_connection_lifetime + ) + self._pool_timeout = pool_timeout async def _async__init__(self): if self._initialized: @@ -578,7 +600,7 @@ async def execute( self, query: str, *args, - timeout: Optional[float]=None, + timeout: Optional[float] = None, ) -> str: """Execute an SQL command (or commands). @@ -596,7 +618,7 @@ async def executemany( command: str, args, *, - timeout: Optional[float]=None, + timeout: Optional[float] = None, ): """Execute an SQL *command* for each sequence of arguments in *args*. @@ -853,6 +875,7 @@ def acquire(self, *, timeout=None): """Acquire a database connection from the pool. :param float timeout: A timeout for acquiring a Connection. + If not specified, defaults to the pool's *pool_timeout*. :return: An instance of :class:`~asyncpg.connection.Connection`. Can be used in an ``await`` expression or with an ``async with`` block. @@ -892,11 +915,16 @@ async def _acquire_impl(): raise exceptions.InterfaceError('pool is closing') self._check_init() - if timeout is None: + # Use pool_timeout as fallback if no timeout specified + effective_timeout = timeout or self._pool_timeout + + if effective_timeout is None: return await _acquire_impl() else: return await compat.wait_for( - _acquire_impl(), timeout=timeout) + _acquire_impl(), + timeout=effective_timeout + ) async def release(self, connection, *, timeout=None): """Release a database connection back to the pool. @@ -906,7 +934,8 @@ async def release(self, connection, *, timeout=None): :param float timeout: A timeout for releasing the connection. If not specified, defaults to the timeout provided in the corresponding call to the - :meth:`Pool.acquire() ` method. + :meth:`Pool.acquire() ` method, or + to the pool's *pool_timeout* if no acquire timeout was set. .. versionchanged:: 0.14.0 Added the *timeout* parameter. @@ -929,7 +958,7 @@ async def release(self, connection, *, timeout=None): ch = connection._holder if timeout is None: - timeout = ch._timeout + timeout = ch._timeout or self._pool_timeout # Use asyncio.shield() to guarantee that task cancellation # does not prevent the connection from being returned to the @@ -1065,26 +1094,32 @@ async def __aexit__( self.done = True con = self.connection self.connection = None - await self.pool.release(con) + # Use the acquire timeout if set, otherwise fall back to pool_timeout + release_timeout = self.timeout or self.pool._pool_timeout + await self.pool.release(con, timeout=release_timeout) def __await__(self): self.done = True return self.pool._acquire(self.timeout).__await__() -def create_pool(dsn=None, *, - min_size=10, - max_size=10, - max_queries=50000, - max_inactive_connection_lifetime=300.0, - connect=None, - setup=None, - init=None, - reset=None, - loop=None, - connection_class=connection.Connection, - record_class=protocol.Record, - **connect_kwargs): +def create_pool( + dsn=None, + *, + min_size=10, + max_size=10, + max_queries=50000, + max_inactive_connection_lifetime=300.0, + pool_timeout=None, + connect=None, + setup=None, + init=None, + reset=None, + loop=None, + connection_class=connection.Connection, + record_class=protocol.Record, + **connect_kwargs, +): r"""Create a connection pool. Can be used either with an ``async with`` block: @@ -1161,6 +1196,11 @@ def create_pool(dsn=None, *, Number of seconds after which inactive connections in the pool will be closed. Pass ``0`` to disable this mechanism. + :param float pool_timeout: + Default timeout for pool operations (connection acquire and release). + If not specified, pool operations may hang indefinitely. Individual + operations can override this with their own timeout parameters. + :param coroutine connect: A coroutine that is called instead of :func:`~asyncpg.connection.connect` whenever the pool needs to make a @@ -1238,6 +1278,7 @@ def create_pool(dsn=None, *, min_size=min_size, max_size=max_size, max_queries=max_queries, + pool_timeout=pool_timeout, loop=loop, connect=connect, setup=setup, diff --git a/tests/test_pool.py b/tests/test_pool.py index 3f10ae5c..7550f5f4 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -29,9 +29,9 @@ class SlowResetConnection(pg_connection.Connection): """Connection class to simulate races with Connection.reset().""" - async def reset(self, *, timeout=None): + async def _reset(self): await asyncio.sleep(0.2) - return await super().reset(timeout=timeout) + return await super()._reset() class SlowCancelConnection(pg_connection.Connection): @@ -1004,6 +1004,46 @@ async def worker(): conn = await pool.acquire(timeout=0.1) await pool.release(conn) + async def test_pool_timeout_acquire_timeout(self): + pool = await self.create_pool( + database='postgres', + min_size=1, + max_size=1, # Only 1 connection to force timeout + pool_timeout=0.1 + ) + + # First acquire the only connection + conn1 = await pool.acquire() + + # Now try to acquire another - should timeout due to pool_timeout + start_time = time.monotonic() + with self.assertRaises(asyncio.TimeoutError): + await pool.acquire() + end_time = time.monotonic() + + self.assertLess(end_time - start_time, 0.2) + + await pool.release(conn1) + await pool.close() + + async def test_pool_timeout_release_with_slow_reset(self): + pool = await self.create_pool( + database='postgres', + min_size=1, + max_size=1, + pool_timeout=0.1, + connection_class=SlowResetConnection, + ) + + start_time = time.monotonic() + with self.assertRaises(asyncio.TimeoutError): + conn = await pool.acquire() + await pool.release(conn) + end_time = time.monotonic() + + self.assertLess(end_time - start_time, 0.2) + await pool.close() + @unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster') class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):