Skip to content

Commit 609b740

Browse files
authored
Merge pull request #3 from dapper91/dev
- Contrib bugs fixed
2 parents 9560139 + 742c484 commit 609b740

File tree

13 files changed

+262
-16
lines changed

13 files changed

+262
-16
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@ Changelog
22
=========
33

44

5-
0.1.0 (2021-03-15)
5+
0.1.1 (2023-03-17)
6+
------------------
7+
8+
- Contrib bugs fixed
9+
10+
11+
0.1.0 (2023-03-15)
612
------------------
713

814
- Initial release

examples/connection_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ def create(self, endpoint: SslEndpoint, timeout: Optional[float] = None) -> Conn
2222
hostname, port = endpoint
2323

2424
sock = self._ssl.wrap_socket(socket.socket(type=socket.SOCK_STREAM), server_hostname=hostname)
25+
26+
orig_timeout = sock.gettimeout()
2527
sock.settimeout(timeout)
26-
sock.connect((hostname, port))
28+
try:
29+
sock.connect((hostname, port))
30+
finally:
31+
sock.settimeout(orig_timeout)
2732

2833
return sock
2934

generic_connection_pool/asyncio.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def connection(
119119
timeout: Optional[float] = None,
120120
) -> AsyncGenerator[ConnectionT, None]:
121121
"""
122-
Acquires a connection form the pool.
122+
Acquires a connection from the pool.
123123
124124
:param endpoint: connection endpoint
125125
:param timeout: number of seconds to wait. If timeout is reached :py:class:`asyncio.TimeoutError` is raised.
@@ -134,7 +134,7 @@ async def connection(
134134

135135
async def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
136136
"""
137-
Acquires a connection form the pool.
137+
Acquires a connection from the pool.
138138
139139
:param endpoint: connection endpoint
140140
:param timeout: number of seconds to wait. If timeout is reached :py:class:`asyncio.TimeoutError` is raised.
@@ -262,7 +262,7 @@ async def _dispose_connection(
262262
logger.error("connection disposal timed-out: %s", conn_info.endpoint)
263263
raise
264264
except Exception as e:
265-
logger.exception("connection disposal failed: %s", e)
265+
logger.error("connection disposal failed: %s", e)
266266
return False
267267

268268
logger.debug("connection disposed: %s", conn_info.endpoint)

generic_connection_pool/contrib/psycopg2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Mapping, Optional
1+
from typing import List, Mapping, Optional
22

33
import psycopg2.extensions
44

55
from generic_connection_pool.threding import BaseConnectionManager
66

77
DbEndpoint = str
88
Connection = psycopg2.extensions.connection
9-
DsnParameters = Mapping[str, Mapping[str, str]]
9+
DsnParameters = Mapping[DbEndpoint, Mapping[str, str]]
1010

1111

1212
class DbConnectionManager(BaseConnectionManager[DbEndpoint, Connection]):
@@ -28,8 +28,14 @@ def dispose(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[floa
2828
def check_aliveness(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[float] = None) -> bool:
2929
try:
3030
with conn.cursor() as cur:
31-
cur.execute("SELECT 1;")
31+
query: List[str] = []
32+
if timeout is not None:
33+
query.append(f"SET statement_timeout = '{int(timeout)}s';")
34+
35+
query.append("SELECT 1;")
36+
cur.execute(''.join(query))
3237
cur.fetchone()
38+
3339
except (psycopg2.Error, OSError):
3440
return False
3541

generic_connection_pool/contrib/socket.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ def create(self, endpoint: TcpEndpoint, timeout: Optional[float] = None) -> sock
2727
raise RuntimeError("unsupported address version type: %s", addr.version)
2828

2929
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
30+
31+
orig_timeout = sock.gettimeout()
3032
sock.settimeout(timeout)
31-
sock.connect((str(addr), port))
33+
try:
34+
sock.connect((str(addr), port))
35+
finally:
36+
sock.settimeout(orig_timeout)
3237

3338
return sock
3439

@@ -57,8 +62,12 @@ def create(self, endpoint: SslEndpoint, timeout: Optional[float] = None) -> SSLS
5762
hostname, port = endpoint
5863

5964
sock = self._ssl.wrap_socket(socket.socket(type=socket.SOCK_STREAM), server_hostname=hostname)
65+
orig_timeout = sock.gettimeout()
6066
sock.settimeout(timeout)
61-
sock.connect((hostname, port))
67+
try:
68+
sock.connect((hostname, port))
69+
finally:
70+
sock.settimeout(orig_timeout)
6271

6372
return sock
6473

generic_connection_pool/contrib/socket_async.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,22 @@ def __init__(self, ssl: Union[None, bool, SSLContext]):
5454

5555
async def create(self, endpoint: TcpStreamEndpoint) -> TcpStream:
5656
hostname, port = endpoint
57+
server_hostname = hostname if self._ssl is not None else None
58+
5759
reader, writer = await asyncio.open_connection(
5860
hostname,
5961
port,
60-
server_hostname=hostname,
62+
server_hostname=server_hostname,
6163
ssl=self._ssl,
6264
)
6365

6466
return reader, writer
6567

6668
async def dispose(self, endpoint: TcpStreamEndpoint, conn: TcpStream) -> None:
6769
reader, writer = conn
68-
writer.write_eof()
70+
if writer.can_write_eof():
71+
writer.write_eof()
72+
6973
writer.close()
7074
await writer.wait_closed()
7175

generic_connection_pool/threding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __init__(
123123
@contextlib.contextmanager
124124
def connection(self, endpoint: EndpointT, timeout: Optional[float] = None) -> Generator[ConnectionT, None, None]:
125125
"""
126-
Acquires a connection form the pool.
126+
Acquires a connection from the pool.
127127
128128
:param endpoint: connection endpoint
129129
:param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
@@ -138,7 +138,7 @@ def connection(self, endpoint: EndpointT, timeout: Optional[float] = None) -> Ge
138138

139139
def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
140140
"""
141-
Acquires a connection form the pool.
141+
Acquires a connection from the pool.
142142
143143
:param endpoint: connection endpoint
144144
:param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
@@ -261,7 +261,7 @@ def _dispose_connection(self, conn_info: ConnectionInfo[EndpointT, ConnectionT],
261261
logger.error("connection disposal timed-out: %s", conn_info.endpoint)
262262
raise
263263
except Exception as e:
264-
logger.exception("connection disposal failed: %s", e)
264+
logger.error("connection disposal failed: %s", e)
265265
return False
266266

267267
logger.debug("connection disposed: %s", conn_info.endpoint)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "generic-connection-pool"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
description = "generic connection pool"
55
authors = ["Dmitry Pershin <[email protected]>"]
66
license = "Unlicense"

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import random
2+
from pathlib import Path
23

34
import pytest
45

56

7+
@pytest.fixture
8+
def test_dir() -> Path:
9+
return Path(__file__).parent
10+
11+
12+
@pytest.fixture
13+
def resource_dir(test_dir) -> Path:
14+
return test_dir / 'resources'
15+
16+
617
@pytest.fixture(autouse=True)
718
def init_random():
819
random.seed(0)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import asyncio
2+
import ssl
3+
from ipaddress import IPv4Address
4+
from pathlib import Path
5+
from typing import Tuple
6+
7+
import pytest
8+
import pytest_asyncio.plugin
9+
10+
from generic_connection_pool.asyncio import ConnectionPool
11+
from generic_connection_pool.contrib.socket_async import TcpSocketConnectionManager, TcpStreamConnectionManager
12+
13+
14+
@pytest.fixture
15+
async def tcp_server(request: pytest_asyncio.plugin.SubRequest, resource_dir: Path):
16+
params = getattr(request, 'param', {})
17+
user_ssl = params.get('use_ssl', False)
18+
19+
if user_ssl:
20+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
21+
context.load_cert_chain(resource_dir / 'ssl.cert', resource_dir / 'ssl.key')
22+
else:
23+
context = None
24+
25+
hostname, addr, port = 'localhost', IPv4Address('127.0.0.1'), 10000
26+
27+
async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
28+
data = await reader.read(1024)
29+
writer.write(data)
30+
await writer.drain()
31+
writer.close()
32+
await writer.wait_closed()
33+
34+
async with (server := await asyncio.start_server(echo, host=hostname, port=port, ssl=context, reuse_port=True)):
35+
server_task = asyncio.create_task(server.serve_forever())
36+
yield hostname, addr, port
37+
server_task.cancel()
38+
try:
39+
await server_task
40+
except asyncio.CancelledError:
41+
pass
42+
43+
44+
async def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]):
45+
loop = asyncio.get_running_loop()
46+
hostname, addr, port = tcp_server
47+
48+
pool = ConnectionPool(TcpSocketConnectionManager())
49+
async with pool.connection((addr, port)) as sock:
50+
request = b'test'
51+
await loop.sock_sendall(sock, request)
52+
response = await loop.sock_recv(sock, len(request))
53+
assert response == request
54+
55+
await pool.close()
56+
57+
58+
@pytest.mark.parametrize(
59+
'use_ssl, tcp_server', [
60+
(True, {'use_ssl': True}),
61+
(False, {'use_ssl': False}),
62+
],
63+
indirect=['tcp_server'],
64+
)
65+
async def test_ssl_stream_manager(resource_dir: Path, use_ssl: bool, tcp_server: Tuple[str, IPv4Address, int]):
66+
hostname, addr, port = tcp_server
67+
68+
if use_ssl:
69+
ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')
70+
else:
71+
ssl_context = None
72+
73+
pool = ConnectionPool(TcpStreamConnectionManager(ssl_context))
74+
async with pool.connection((hostname, port)) as (reader, writer):
75+
request = b'test'
76+
writer.write(request)
77+
await writer.drain()
78+
response = await reader.read()
79+
assert response == request
80+
81+
await pool.close()

0 commit comments

Comments
 (0)