Skip to content

Commit 83c2c3c

Browse files
authored
Merge pull request #5 from dapper91/dev
- socket_timeout contex manager works with ssl sockets. - graceful_timeout default set to 0. - connection release bug fixed.
2 parents 609b740 + 0362b85 commit 83c2c3c

File tree

12 files changed

+293
-193
lines changed

12 files changed

+293
-193
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
Changelog
22
=========
33

4+
0.2.0 (2023-04-19)
5+
------------------
6+
7+
- socket_timeout contex manager works with ssl sockets.
8+
- graceful_timeout default set to 0.
9+
- connection release bug fixed.
10+
411

512
0.1.1 (2023-03-17)
613
------------------

generic_connection_pool/asyncio.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,14 @@ async def release(self, conn: ConnectionT, endpoint: EndpointT) -> None:
155155
dispose_batch_size = self._dispose_batch_size or int(math.log2(self._pool_size + 1)) + 1
156156
await self._collect_disposable_connections(dispose_batch_size)
157157

158-
async def close(self, graceful_timeout: Optional[float] = None, timeout: Optional[float] = None) -> None:
158+
async def close(self, graceful_timeout: float = 0.0, timeout: Optional[float] = None) -> None:
159159
"""
160160
Closes the connection pool.
161161
162-
:param graceful_timeout: timeout within which the pool waits all acquired connection to be released
162+
:param graceful_timeout: timeout within which the pool waits for all acquired connection to be released
163163
:param timeout: timeout after which the pool closes all connection despite they are released or not
164164
"""
165165

166-
if graceful_timeout is None:
167-
graceful_timeout = timeout
168-
169166
if graceful_timeout is not None and timeout is not None:
170167
assert timeout >= graceful_timeout, "timeout can't be less than graceful_timeout"
171168

@@ -403,7 +400,7 @@ async def _close_connections(
403400
try:
404401
while released:
405402
conn_info = released[-1]
406-
await self._dispose_connection(conn_info, timeout=graceful_timer.remains)
403+
await self._dispose_connection(conn_info, timeout=global_timer.remains)
407404
released.pop()
408405
except asyncio.TimeoutError:
409406
await asyncio.shield(self._return_released_conns(released))
@@ -422,5 +419,5 @@ async def _return_released_conns(self, released: List[ConnectionInfo[EndpointT,
422419
for conn_info in released:
423420
pool = self._pools[conn_info.endpoint]
424421
pool.queue[conn_info.conn] = conn_info
425-
pool.access_queue.remove((conn_info.accessed_at, conn_info.conn))
422+
pool.access_queue.push((conn_info.accessed_at, conn_info.conn))
426423
self._pool_size += 1

generic_connection_pool/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def remains(self) -> Optional[float]:
4242
if self._timeout is None:
4343
return None
4444

45-
return self._timeout - self.elapsed
45+
return max(0.0, self._timeout - self.elapsed)
4646

4747
@property
4848
def timedout(self) -> bool:

generic_connection_pool/contrib/socket.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import contextlib
12
import socket
23
from ipaddress import IPv4Address, IPv6Address
34
from ssl import SSLContext, SSLSocket
4-
from typing import Optional, Tuple, Union
5+
from typing import Generator, Optional, Tuple, Union
56

67
from generic_connection_pool.threding import BaseConnectionManager
78

@@ -11,6 +12,25 @@
1112
TcpEndpoint = Tuple[IpAddress, Port]
1213

1314

15+
@contextlib.contextmanager
16+
def socket_timeout(sock: socket.socket, timeout: Optional[float]) -> Generator[None, None, None]:
17+
if timeout is None:
18+
yield
19+
return
20+
21+
orig_timeout = sock.gettimeout()
22+
sock.settimeout(max(timeout, 1e-6)) # if timeout is 0 set it a small value to prevent non-blocking socket mode
23+
try:
24+
yield
25+
except OSError as e:
26+
if 'timed out' in str(e):
27+
raise TimeoutError
28+
else:
29+
raise
30+
finally:
31+
sock.settimeout(orig_timeout)
32+
33+
1434
class TcpSocketConnectionManager(BaseConnectionManager[TcpEndpoint, socket.socket]):
1535
"""
1636
TCP socket connection manager.
@@ -28,17 +48,12 @@ def create(self, endpoint: TcpEndpoint, timeout: Optional[float] = None) -> sock
2848

2949
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
3050

31-
orig_timeout = sock.gettimeout()
32-
sock.settimeout(timeout)
33-
try:
51+
with socket_timeout(sock, timeout):
3452
sock.connect((str(addr), port))
35-
finally:
36-
sock.settimeout(orig_timeout)
3753

3854
return sock
3955

4056
def dispose(self, endpoint: TcpEndpoint, conn: socket.socket, timeout: Optional[float] = None) -> None:
41-
conn.settimeout(timeout)
4257
try:
4358
conn.shutdown(socket.SHUT_RDWR)
4459
except OSError:
@@ -62,17 +77,13 @@ def create(self, endpoint: SslEndpoint, timeout: Optional[float] = None) -> SSLS
6277
hostname, port = endpoint
6378

6479
sock = self._ssl.wrap_socket(socket.socket(type=socket.SOCK_STREAM), server_hostname=hostname)
65-
orig_timeout = sock.gettimeout()
66-
sock.settimeout(timeout)
67-
try:
80+
81+
with socket_timeout(sock, timeout):
6882
sock.connect((hostname, port))
69-
finally:
70-
sock.settimeout(orig_timeout)
7183

7284
return sock
7385

7486
def dispose(self, endpoint: SslEndpoint, conn: SSLSocket, timeout: Optional[float] = None) -> None:
75-
conn.settimeout(timeout)
7687
try:
7788
conn.shutdown(socket.SHUT_RDWR)
7889
except OSError:

generic_connection_pool/contrib/socket_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class TcpStreamConnectionManager(BaseConnectionManager[TcpStreamEndpoint, TcpStr
4949
TCP stream connection manager.
5050
"""
5151

52-
def __init__(self, ssl: Union[None, bool, SSLContext]):
52+
def __init__(self, ssl: Union[None, bool, SSLContext] = None):
5353
self._ssl = ssl
5454

5555
async def create(self, endpoint: TcpStreamEndpoint) -> TcpStream:

generic_connection_pool/threding.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,14 @@ def release(self, conn: ConnectionT, endpoint: EndpointT) -> None:
159159
dispose_batch_size = self._dispose_batch_size or int(math.log2(self._pool_size + 1)) + 1
160160
self._collect_disposable_connections(dispose_batch_size)
161161

162-
def close(self, graceful_timeout: Optional[float] = None, timeout: Optional[float] = None) -> None:
162+
def close(self, graceful_timeout: float = 0.0, timeout: Optional[float] = None) -> None:
163163
"""
164164
Closes the connection pool.
165165
166-
:param graceful_timeout: timeout within which the pool waits all acquired connection to be released
166+
:param graceful_timeout: timeout within which the pool waits for all acquired connection to be released
167167
:param timeout: timeout after which the pool closes all connection despite they are released or not
168168
"""
169169

170-
if graceful_timeout is None:
171-
graceful_timeout = timeout
172-
173170
if graceful_timeout is not None and timeout is not None:
174171
assert timeout >= graceful_timeout, "timeout can't be less than graceful_timeout"
175172

@@ -401,7 +398,7 @@ def _close_connections(self, graceful_timeout: Optional[float], timeout: Optiona
401398
try:
402399
while released:
403400
conn_info = released[-1]
404-
self._dispose_connection(conn_info, timeout=graceful_timer.remains)
401+
self._dispose_connection(conn_info, timeout=global_timer.remains)
405402
released.pop()
406403
except TimeoutError:
407404
self._return_released_conns(released)
@@ -420,5 +417,5 @@ def _return_released_conns(self, released: List[ConnectionInfo[EndpointT, Connec
420417
for conn_info in released:
421418
pool = self._pools[conn_info.endpoint]
422419
pool.queue[conn_info.conn] = conn_info
423-
pool.access_queue.remove((conn_info.accessed_at, conn_info.conn))
420+
pool.access_queue.push((conn_info.accessed_at, conn_info.conn))
424421
self._pool_size += 1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "generic-connection-pool"
3-
version = "0.1.1"
3+
version = "0.2.0"
44
description = "generic connection pool"
55
authors = ["Dmitry Pershin <[email protected]>"]
66
license = "Unlicense"

tests/conftest.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
11
import random
22
from pathlib import Path
3+
from typing import Generator
34

45
import pytest
56

67

7-
@pytest.fixture
8+
@pytest.fixture(scope='session')
89
def test_dir() -> Path:
910
return Path(__file__).parent
1011

1112

12-
@pytest.fixture
13+
@pytest.fixture(scope='session')
1314
def resource_dir(test_dir) -> Path:
1415
return test_dir / 'resources'
1516

1617

1718
@pytest.fixture(autouse=True)
18-
def init_random():
19+
def init_random() -> None:
1920
random.seed(0)
2021

2122

2223
@pytest.fixture(scope='session')
23-
def sleep_delay(pytestconfig) -> float:
24-
return pytestconfig.getoption('--sleep-delay', 0.05)
24+
def delay(pytestconfig) -> float:
25+
return pytestconfig.getoption('--delay', 0.05)
26+
27+
28+
@pytest.fixture(scope='session')
29+
def port_gen() -> Generator[int, None, None]:
30+
def gen():
31+
for port in range(10000, 65535):
32+
yield port
33+
34+
return gen()

tests/contrib/test_async_socket_manager.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,72 @@
22
import ssl
33
from ipaddress import IPv4Address
44
from pathlib import Path
5-
from typing import Tuple
5+
from typing import AsyncGenerator, Generator, Optional, Tuple
66

77
import pytest
8-
import pytest_asyncio.plugin
98

109
from generic_connection_pool.asyncio import ConnectionPool
1110
from generic_connection_pool.contrib.socket_async import TcpSocketConnectionManager, TcpStreamConnectionManager
1211

1312

14-
@pytest.fixture
15-
async def tcp_server(request: pytest_asyncio.plugin.SubRequest, resource_dir: Path):
16-
params = getattr(request, 'param', {})
17-
user_ssl = params.get('use_ssl', False)
18-
19-
if user_ssl:
20-
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
21-
context.load_cert_chain(resource_dir / 'ssl.cert', resource_dir / 'ssl.key')
22-
else:
23-
context = None
24-
25-
hostname, addr, port = 'localhost', IPv4Address('127.0.0.1'), 10000
26-
27-
async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
28-
data = await reader.read(1024)
13+
class TCPServer:
14+
@staticmethod
15+
async def echo_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
16+
data = await reader.read(1500)
2917
writer.write(data)
3018
await writer.drain()
3119
writer.close()
3220
await writer.wait_closed()
3321

34-
async with (server := await asyncio.start_server(echo, host=hostname, port=port, ssl=context, reuse_port=True)):
35-
server_task = asyncio.create_task(server.serve_forever())
36-
yield hostname, addr, port
37-
server_task.cancel()
38-
try:
39-
await server_task
40-
except asyncio.CancelledError:
41-
pass
22+
def __init__(self, hostname: str, port: int, ssl_ctx: Optional[ssl.SSLContext] = None):
23+
self._hostname = hostname
24+
self._port = port
25+
self._ssl_ctx = ssl_ctx
26+
self._server_task: Optional[asyncio.Task[None]] = None
27+
28+
async def start(self) -> None:
29+
server = await asyncio.start_server(
30+
self.echo_handler,
31+
host=self._hostname,
32+
port=self._port,
33+
ssl=self._ssl_ctx,
34+
reuse_port=True,
35+
)
36+
self._server_task = asyncio.create_task(server.serve_forever())
37+
38+
async def stop(self) -> None:
39+
if (server_task := self._server_task) is not None:
40+
server_task.cancel()
41+
try:
42+
await server_task
43+
except asyncio.CancelledError:
44+
pass
45+
46+
47+
@pytest.fixture
48+
async def tcp_server(port_gen: Generator[int, None, None]) -> AsyncGenerator[Tuple[IPv4Address, int], None]:
49+
addr, port = IPv4Address('127.0.0.1'), next(port_gen)
50+
server = TCPServer(str(addr), port)
51+
await server.start()
52+
yield addr, port
53+
await server.stop()
54+
55+
56+
@pytest.fixture
57+
async def ssl_server(resource_dir: Path, port_gen: Generator[int, None, None]) -> AsyncGenerator[Tuple[str, int], None]:
58+
hostname, port = 'localhost', next(port_gen)
59+
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
60+
ssl_ctx.load_cert_chain(resource_dir / 'ssl.cert', resource_dir / 'ssl.key')
61+
62+
server = TCPServer(hostname, port, ssl_ctx=ssl_ctx)
63+
await server.start()
64+
yield hostname, port
65+
await server.stop()
4266

4367

44-
async def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]):
68+
async def test_tcp_socket_manager(tcp_server: Tuple[IPv4Address, int]):
4569
loop = asyncio.get_running_loop()
46-
hostname, addr, port = tcp_server
70+
addr, port = tcp_server
4771

4872
pool = ConnectionPool(TcpSocketConnectionManager())
4973
async with pool.connection((addr, port)) as sock:
@@ -55,20 +79,23 @@ async def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]):
5579
await pool.close()
5680

5781

58-
@pytest.mark.parametrize(
59-
'use_ssl, tcp_server', [
60-
(True, {'use_ssl': True}),
61-
(False, {'use_ssl': False}),
62-
],
63-
indirect=['tcp_server'],
64-
)
65-
async def test_ssl_stream_manager(resource_dir: Path, use_ssl: bool, tcp_server: Tuple[str, IPv4Address, int]):
66-
hostname, addr, port = tcp_server
67-
68-
if use_ssl:
69-
ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')
70-
else:
71-
ssl_context = None
82+
async def test_tcp_stream_manager(resource_dir: Path, tcp_server: Tuple[IPv4Address, int]):
83+
addr, port = tcp_server
84+
85+
pool = ConnectionPool(TcpStreamConnectionManager(ssl=None))
86+
async with pool.connection((str(addr), port)) as (reader, writer):
87+
request = b'test'
88+
writer.write(request)
89+
await writer.drain()
90+
response = await reader.read()
91+
assert response == request
92+
93+
await pool.close()
94+
95+
96+
async def test_ssl_stream_manager(resource_dir: Path, ssl_server: Tuple[str, int]):
97+
hostname, port = ssl_server
98+
ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')
7299

73100
pool = ConnectionPool(TcpStreamConnectionManager(ssl_context))
74101
async with pool.connection((hostname, port)) as (reader, writer):

0 commit comments

Comments
 (0)