Skip to content

Commit 06ad990

Browse files
committed
Add pool_timeout parameter to prevent pool operation hangs (#431)
1 parent 5b14653 commit 06ad990

File tree

2 files changed

+130
-49
lines changed

2 files changed

+130
-49
lines changed

asyncpg/pool.py

Lines changed: 88 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,16 @@
77
from __future__ import annotations
88

99
import asyncio
10-
from collections.abc import Awaitable, Callable
1110
import functools
1211
import inspect
1312
import logging
1413
import time
14+
import warnings
15+
from collections.abc import Awaitable, Callable
1516
from types import TracebackType
1617
from typing import Any, Optional, Type
17-
import warnings
18-
19-
from . import compat
20-
from . import connection
21-
from . import exceptions
22-
from . import protocol
2318

19+
from . import compat, connection, exceptions, protocol
2420

2521
logger = logging.getLogger(__name__)
2622

@@ -338,27 +334,46 @@ class Pool:
338334
"""
339335

340336
__slots__ = (
341-
'_queue', '_loop', '_minsize', '_maxsize',
342-
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
343-
'_holders', '_initialized', '_initializing', '_closing',
344-
'_closed', '_connection_class', '_record_class', '_generation',
345-
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
337+
"_queue",
338+
"_loop",
339+
"_minsize",
340+
"_maxsize",
341+
"_init",
342+
"_connect",
343+
"_reset",
344+
"_connect_args",
345+
"_connect_kwargs",
346+
"_holders",
347+
"_initialized",
348+
"_initializing",
349+
"_closing",
350+
"_closed",
351+
"_connection_class",
352+
"_record_class",
353+
"_generation",
354+
"_setup",
355+
"_max_queries",
356+
"_max_inactive_connection_lifetime",
357+
"_pool_timeout",
346358
)
347359

348-
def __init__(self, *connect_args,
349-
min_size,
350-
max_size,
351-
max_queries,
352-
max_inactive_connection_lifetime,
353-
connect=None,
354-
setup=None,
355-
init=None,
356-
reset=None,
357-
loop,
358-
connection_class,
359-
record_class,
360-
**connect_kwargs):
361-
360+
def __init__(
361+
self,
362+
*connect_args,
363+
min_size,
364+
max_size,
365+
max_queries,
366+
max_inactive_connection_lifetime,
367+
pool_timeout=None,
368+
connect=None,
369+
setup=None,
370+
init=None,
371+
reset=None,
372+
loop,
373+
connection_class,
374+
record_class,
375+
**connect_kwargs,
376+
):
362377
if len(connect_args) > 1:
363378
warnings.warn(
364379
"Passing multiple positional arguments to asyncpg.Pool "
@@ -389,6 +404,11 @@ def __init__(self, *connect_args,
389404
'max_inactive_connection_lifetime is expected to be greater '
390405
'or equal to zero')
391406

407+
if pool_timeout is not None and pool_timeout <= 0:
408+
raise ValueError(
409+
"pool_timeout is expected to be greater than zero or None"
410+
)
411+
392412
if not issubclass(connection_class, connection.Connection):
393413
raise TypeError(
394414
'connection_class is expected to be a subclass of '
@@ -423,8 +443,10 @@ def __init__(self, *connect_args,
423443
self._reset = reset
424444

425445
self._max_queries = max_queries
426-
self._max_inactive_connection_lifetime = \
446+
self._max_inactive_connection_lifetime = (
427447
max_inactive_connection_lifetime
448+
)
449+
self._pool_timeout = pool_timeout
428450

429451
async def _async__init__(self):
430452
if self._initialized:
@@ -578,7 +600,7 @@ async def execute(
578600
self,
579601
query: str,
580602
*args,
581-
timeout: Optional[float]=None,
603+
timeout: Optional[float] = None,
582604
) -> str:
583605
"""Execute an SQL command (or commands).
584606
@@ -596,7 +618,7 @@ async def executemany(
596618
command: str,
597619
args,
598620
*,
599-
timeout: Optional[float]=None,
621+
timeout: Optional[float] = None,
600622
):
601623
"""Execute an SQL *command* for each sequence of arguments in *args*.
602624
@@ -853,6 +875,7 @@ def acquire(self, *, timeout=None):
853875
"""Acquire a database connection from the pool.
854876
855877
:param float timeout: A timeout for acquiring a Connection.
878+
If not specified, defaults to the pool's *pool_timeout*.
856879
:return: An instance of :class:`~asyncpg.connection.Connection`.
857880
858881
Can be used in an ``await`` expression or with an ``async with`` block.
@@ -892,11 +915,16 @@ async def _acquire_impl():
892915
raise exceptions.InterfaceError('pool is closing')
893916
self._check_init()
894917

895-
if timeout is None:
918+
# Use pool_timeout as fallback if no timeout specified
919+
effective_timeout = timeout or self._pool_timeout
920+
921+
if effective_timeout is None:
896922
return await _acquire_impl()
897923
else:
898924
return await compat.wait_for(
899-
_acquire_impl(), timeout=timeout)
925+
_acquire_impl(),
926+
timeout=effective_timeout
927+
)
900928

901929
async def release(self, connection, *, timeout=None):
902930
"""Release a database connection back to the pool.
@@ -906,7 +934,8 @@ async def release(self, connection, *, timeout=None):
906934
:param float timeout:
907935
A timeout for releasing the connection. If not specified, defaults
908936
to the timeout provided in the corresponding call to the
909-
:meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method.
937+
:meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method, or
938+
to the pool's *pool_timeout* if no acquire timeout was set.
910939
911940
.. versionchanged:: 0.14.0
912941
Added the *timeout* parameter.
@@ -929,7 +958,7 @@ async def release(self, connection, *, timeout=None):
929958

930959
ch = connection._holder
931960
if timeout is None:
932-
timeout = ch._timeout
961+
timeout = ch._timeout or self._pool_timeout
933962

934963
# Use asyncio.shield() to guarantee that task cancellation
935964
# does not prevent the connection from being returned to the
@@ -1065,26 +1094,32 @@ async def __aexit__(
10651094
self.done = True
10661095
con = self.connection
10671096
self.connection = None
1068-
await self.pool.release(con)
1097+
# Use the acquire timeout if set, otherwise fall back to pool_timeout
1098+
release_timeout = self.timeout or self.pool._pool_timeout
1099+
await self.pool.release(con, timeout=release_timeout)
10691100

10701101
def __await__(self):
10711102
self.done = True
10721103
return self.pool._acquire(self.timeout).__await__()
10731104

10741105

1075-
def create_pool(dsn=None, *,
1076-
min_size=10,
1077-
max_size=10,
1078-
max_queries=50000,
1079-
max_inactive_connection_lifetime=300.0,
1080-
connect=None,
1081-
setup=None,
1082-
init=None,
1083-
reset=None,
1084-
loop=None,
1085-
connection_class=connection.Connection,
1086-
record_class=protocol.Record,
1087-
**connect_kwargs):
1106+
def create_pool(
1107+
dsn=None,
1108+
*,
1109+
min_size=10,
1110+
max_size=10,
1111+
max_queries=50000,
1112+
max_inactive_connection_lifetime=300.0,
1113+
pool_timeout=None,
1114+
connect=None,
1115+
setup=None,
1116+
init=None,
1117+
reset=None,
1118+
loop=None,
1119+
connection_class=connection.Connection,
1120+
record_class=protocol.Record,
1121+
**connect_kwargs,
1122+
):
10881123
r"""Create a connection pool.
10891124
10901125
Can be used either with an ``async with`` block:
@@ -1161,6 +1196,11 @@ def create_pool(dsn=None, *,
11611196
Number of seconds after which inactive connections in the
11621197
pool will be closed. Pass ``0`` to disable this mechanism.
11631198
1199+
:param float pool_timeout:
1200+
Default timeout for pool operations (connection acquire and release).
1201+
If not specified, pool operations may hang indefinitely. Individual
1202+
operations can override this with their own timeout parameters.
1203+
11641204
:param coroutine connect:
11651205
A coroutine that is called instead of
11661206
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
@@ -1238,6 +1278,7 @@ def create_pool(dsn=None, *,
12381278
min_size=min_size,
12391279
max_size=max_size,
12401280
max_queries=max_queries,
1281+
pool_timeout=pool_timeout,
12411282
loop=loop,
12421283
connect=connect,
12431284
setup=setup,

tests/test_pool.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929

3030
class SlowResetConnection(pg_connection.Connection):
3131
"""Connection class to simulate races with Connection.reset()."""
32-
async def reset(self, *, timeout=None):
32+
async def _reset(self):
3333
await asyncio.sleep(0.2)
34-
return await super().reset(timeout=timeout)
34+
return await super()._reset()
3535

3636

3737
class SlowCancelConnection(pg_connection.Connection):
@@ -1004,6 +1004,46 @@ async def worker():
10041004
conn = await pool.acquire(timeout=0.1)
10051005
await pool.release(conn)
10061006

1007+
async def test_pool_timeout_acquire_timeout(self):
1008+
pool = await self.create_pool(
1009+
database='postgres',
1010+
min_size=1,
1011+
max_size=1, # Only 1 connection to force timeout
1012+
pool_timeout=0.1
1013+
)
1014+
1015+
# First acquire the only connection
1016+
conn1 = await pool.acquire()
1017+
1018+
# Now try to acquire another - should timeout due to pool_timeout
1019+
start_time = time.monotonic()
1020+
with self.assertRaises(asyncio.TimeoutError):
1021+
await pool.acquire()
1022+
end_time = time.monotonic()
1023+
1024+
self.assertLess(end_time - start_time, 0.2)
1025+
1026+
await pool.release(conn1)
1027+
await pool.close()
1028+
1029+
async def test_pool_timeout_release_with_slow_reset(self):
1030+
pool = await self.create_pool(
1031+
database='postgres',
1032+
min_size=1,
1033+
max_size=1,
1034+
pool_timeout=0.1,
1035+
connection_class=SlowResetConnection,
1036+
)
1037+
1038+
start_time = time.monotonic()
1039+
with self.assertRaises(asyncio.TimeoutError):
1040+
conn = await pool.acquire()
1041+
await pool.release(conn)
1042+
end_time = time.monotonic()
1043+
1044+
self.assertLess(end_time - start_time, 0.2)
1045+
await pool.close()
1046+
10071047

10081048
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
10091049
class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):

0 commit comments

Comments
 (0)