Skip to content

Commit 744010d

Browse files
authored
Merge pull request #2 from dapper91/contrib-fix
contrib fixes
2 parents 767f255 + 083102e commit 744010d

File tree

11 files changed

+254
-14
lines changed

11 files changed

+254
-14
lines changed

examples/connection_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ def create(self, endpoint: SslEndpoint, timeout: Optional[float] = None) -> Conn
2222
hostname, port = endpoint
2323

2424
sock = self._ssl.wrap_socket(socket.socket(type=socket.SOCK_STREAM), server_hostname=hostname)
25+
26+
orig_timeout = sock.gettimeout()
2527
sock.settimeout(timeout)
26-
sock.connect((hostname, port))
28+
try:
29+
sock.connect((hostname, port))
30+
finally:
31+
sock.settimeout(orig_timeout)
2732

2833
return sock
2934

generic_connection_pool/asyncio.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def connection(
119119
timeout: Optional[float] = None,
120120
) -> AsyncGenerator[ConnectionT, None]:
121121
"""
122-
Acquires a connection form the pool.
122+
Acquires a connection from the pool.
123123
124124
:param endpoint: connection endpoint
125125
:param timeout: number of seconds to wait. If timeout is reached :py:class:`asyncio.TimeoutError` is raised.
@@ -134,7 +134,7 @@ async def connection(
134134

135135
async def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
136136
"""
137-
Acquires a connection form the pool.
137+
Acquires a connection from the pool.
138138
139139
:param endpoint: connection endpoint
140140
:param timeout: number of seconds to wait. If timeout is reached :py:class:`asyncio.TimeoutError` is raised.
@@ -262,7 +262,7 @@ async def _dispose_connection(
262262
logger.error("connection disposal timed-out: %s", conn_info.endpoint)
263263
raise
264264
except Exception as e:
265-
logger.exception("connection disposal failed: %s", e)
265+
logger.error("connection disposal failed: %s", e)
266266
return False
267267

268268
logger.debug("connection disposed: %s", conn_info.endpoint)

generic_connection_pool/contrib/psycopg2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Mapping, Optional
1+
from typing import List, Mapping, Optional
22

33
import psycopg2.extensions
44

55
from generic_connection_pool.threding import BaseConnectionManager
66

77
DbEndpoint = str
88
Connection = psycopg2.extensions.connection
9-
DsnParameters = Mapping[str, Mapping[str, str]]
9+
DsnParameters = Mapping[DbEndpoint, Mapping[str, str]]
1010

1111

1212
class DbConnectionManager(BaseConnectionManager[DbEndpoint, Connection]):
@@ -28,8 +28,14 @@ def dispose(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[floa
2828
def check_aliveness(self, endpoint: DbEndpoint, conn: Connection, timeout: Optional[float] = None) -> bool:
2929
try:
3030
with conn.cursor() as cur:
31-
cur.execute("SELECT 1;")
31+
query: List[str] = []
32+
if timeout is not None:
33+
query.append(f"SET statement_timeout = '{int(timeout)}s';")
34+
35+
query.append("SELECT 1;")
36+
cur.execute(''.join(query))
3237
cur.fetchone()
38+
3339
except (psycopg2.Error, OSError):
3440
return False
3541

generic_connection_pool/contrib/socket.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ def create(self, endpoint: TcpEndpoint, timeout: Optional[float] = None) -> sock
2727
raise RuntimeError("unsupported address version type: %s", addr.version)
2828

2929
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
30+
31+
orig_timeout = sock.gettimeout()
3032
sock.settimeout(timeout)
31-
sock.connect((str(addr), port))
33+
try:
34+
sock.connect((str(addr), port))
35+
finally:
36+
sock.settimeout(orig_timeout)
3237

3338
return sock
3439

@@ -57,8 +62,12 @@ def create(self, endpoint: SslEndpoint, timeout: Optional[float] = None) -> SSLS
5762
hostname, port = endpoint
5863

5964
sock = self._ssl.wrap_socket(socket.socket(type=socket.SOCK_STREAM), server_hostname=hostname)
65+
orig_timeout = sock.gettimeout()
6066
sock.settimeout(timeout)
61-
sock.connect((hostname, port))
67+
try:
68+
sock.connect((hostname, port))
69+
finally:
70+
sock.settimeout(orig_timeout)
6271

6372
return sock
6473

generic_connection_pool/contrib/socket_async.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,22 @@ def __init__(self, ssl: Union[None, bool, SSLContext]):
5454

5555
async def create(self, endpoint: TcpStreamEndpoint) -> TcpStream:
5656
hostname, port = endpoint
57+
server_hostname = hostname if self._ssl is not None else None
58+
5759
reader, writer = await asyncio.open_connection(
5860
hostname,
5961
port,
60-
server_hostname=hostname,
62+
server_hostname=server_hostname,
6163
ssl=self._ssl,
6264
)
6365

6466
return reader, writer
6567

6668
async def dispose(self, endpoint: TcpStreamEndpoint, conn: TcpStream) -> None:
6769
reader, writer = conn
68-
writer.write_eof()
70+
if writer.can_write_eof():
71+
writer.write_eof()
72+
6973
writer.close()
7074
await writer.wait_closed()
7175

generic_connection_pool/threding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __init__(
123123
@contextlib.contextmanager
124124
def connection(self, endpoint: EndpointT, timeout: Optional[float] = None) -> Generator[ConnectionT, None, None]:
125125
"""
126-
Acquires a connection form the pool.
126+
Acquires a connection from the pool.
127127
128128
:param endpoint: connection endpoint
129129
:param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
@@ -138,7 +138,7 @@ def connection(self, endpoint: EndpointT, timeout: Optional[float] = None) -> Ge
138138

139139
def acquire(self, endpoint: EndpointT, timeout: Optional[float] = None) -> ConnectionT:
140140
"""
141-
Acquires a connection form the pool.
141+
Acquires a connection from the pool.
142142
143143
:param endpoint: connection endpoint
144144
:param timeout: number of seconds to wait. If timeout is reached :py:class:`TimeoutError` is raised.
@@ -261,7 +261,7 @@ def _dispose_connection(self, conn_info: ConnectionInfo[EndpointT, ConnectionT],
261261
logger.error("connection disposal timed-out: %s", conn_info.endpoint)
262262
raise
263263
except Exception as e:
264-
logger.exception("connection disposal failed: %s", e)
264+
logger.error("connection disposal failed: %s", e)
265265
return False
266266

267267
logger.debug("connection disposed: %s", conn_info.endpoint)

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import random
2+
from pathlib import Path
23

34
import pytest
45

56

7+
@pytest.fixture
8+
def test_dir() -> Path:
9+
return Path(__file__).parent
10+
11+
12+
@pytest.fixture
13+
def resource_dir(test_dir) -> Path:
14+
return test_dir / 'resources'
15+
16+
617
@pytest.fixture(autouse=True)
718
def init_random():
819
random.seed(0)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import asyncio
2+
import ssl
3+
from ipaddress import IPv4Address
4+
from pathlib import Path
5+
from typing import Tuple
6+
7+
import pytest
8+
import pytest_asyncio.plugin
9+
10+
from generic_connection_pool.asyncio import ConnectionPool
11+
from generic_connection_pool.contrib.socket_async import TcpSocketConnectionManager, TcpStreamConnectionManager
12+
13+
14+
@pytest.fixture
15+
async def tcp_server(request: pytest_asyncio.plugin.SubRequest, resource_dir: Path):
16+
params = getattr(request, 'param', {})
17+
user_ssl = params.get('use_ssl', False)
18+
19+
if user_ssl:
20+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
21+
context.load_cert_chain(resource_dir / 'ssl.cert', resource_dir / 'ssl.key')
22+
else:
23+
context = None
24+
25+
hostname, addr, port = 'localhost', IPv4Address('127.0.0.1'), 10000
26+
27+
async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
28+
data = await reader.read(1024)
29+
writer.write(data)
30+
await writer.drain()
31+
writer.close()
32+
await writer.wait_closed()
33+
34+
async with (server := await asyncio.start_server(echo, host=hostname, port=port, ssl=context, reuse_port=True)):
35+
server_task = asyncio.create_task(server.serve_forever())
36+
yield hostname, addr, port
37+
server_task.cancel()
38+
try:
39+
await server_task
40+
except asyncio.CancelledError:
41+
pass
42+
43+
44+
async def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]):
45+
loop = asyncio.get_running_loop()
46+
hostname, addr, port = tcp_server
47+
48+
pool = ConnectionPool(TcpSocketConnectionManager())
49+
async with pool.connection((addr, port)) as sock:
50+
request = b'test'
51+
await loop.sock_sendall(sock, request)
52+
response = await loop.sock_recv(sock, len(request))
53+
assert response == request
54+
55+
await pool.close()
56+
57+
58+
@pytest.mark.parametrize(
59+
'use_ssl, tcp_server', [
60+
(True, {'use_ssl': True}),
61+
(False, {'use_ssl': False}),
62+
],
63+
indirect=['tcp_server'],
64+
)
65+
async def test_ssl_stream_manager(resource_dir: Path, use_ssl: bool, tcp_server: Tuple[str, IPv4Address, int]):
66+
hostname, addr, port = tcp_server
67+
68+
if use_ssl:
69+
ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')
70+
else:
71+
ssl_context = None
72+
73+
pool = ConnectionPool(TcpStreamConnectionManager(ssl_context))
74+
async with pool.connection((hostname, port)) as (reader, writer):
75+
request = b'test'
76+
writer.write(request)
77+
await writer.drain()
78+
response = await reader.read()
79+
assert response == request
80+
81+
await pool.close()
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import socketserver
2+
import ssl
3+
import threading
4+
from ipaddress import IPv4Address
5+
from pathlib import Path
6+
from typing import Tuple
7+
8+
import pytest
9+
import pytest_asyncio.plugin
10+
11+
from generic_connection_pool.contrib.socket import SslSocketConnectionManager, TcpSocketConnectionManager
12+
from generic_connection_pool.threding import ConnectionPool
13+
14+
15+
@pytest.fixture
16+
async def tcp_server(request: pytest_asyncio.plugin.SubRequest, resource_dir: Path):
17+
params = getattr(request, 'param', {})
18+
user_ssl = params.get('use_ssl', False)
19+
20+
hostname, addr, port = 'localhost', IPv4Address('127.0.0.1'), 10000
21+
22+
class RequestHandler(socketserver.BaseRequestHandler):
23+
def handle(self):
24+
data = self.request.recv(1024)
25+
self.request.sendall(data)
26+
27+
class TCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
28+
allow_reuse_address = True
29+
30+
class SSLServer(TCPServer):
31+
def get_request(self):
32+
socket, addr = super().get_request()
33+
ssl_socket = ssl.wrap_socket(
34+
socket,
35+
server_side=True,
36+
keyfile=resource_dir / 'ssl.key',
37+
certfile=resource_dir / 'ssl.cert',
38+
)
39+
return ssl_socket, addr
40+
41+
Server = SSLServer if user_ssl else TCPServer
42+
43+
server = Server((hostname, port), RequestHandler)
44+
with server:
45+
server_thread = threading.Thread(target=server.serve_forever)
46+
server_thread.daemon = True
47+
server_thread.start()
48+
yield hostname, addr, port
49+
server.shutdown()
50+
51+
52+
@pytest.mark.parametrize('tcp_server', [{'use_ssl': False}], indirect=['tcp_server'])
53+
def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]):
54+
hostname, addr, port = tcp_server
55+
56+
pool = ConnectionPool(TcpSocketConnectionManager())
57+
with pool.connection((addr, port)) as sock:
58+
request = b'test'
59+
sock.sendall(request)
60+
response = sock.recv(len(request))
61+
assert response == request
62+
63+
pool.close()
64+
65+
66+
@pytest.mark.parametrize('tcp_server', [{'use_ssl': True}], indirect=['tcp_server'])
67+
def test_ssl_socket_manager(resource_dir: Path, tcp_server: Tuple[str, IPv4Address, int]):
68+
hostname, addr, port = tcp_server
69+
ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')
70+
71+
pool = ConnectionPool(SslSocketConnectionManager(ssl_context))
72+
with pool.connection((hostname, port)) as sock:
73+
request = b'test'
74+
sock.sendall(request)
75+
response = sock.recv(len(request))
76+
assert response == request
77+
78+
pool.close()

tests/resources/ssl.cert

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
-----BEGIN CERTIFICATE-----
2+
MIIC2DCCAcACCQCPGAEZ/izmXDANBgkqhkiG9w0BAQsFADAuMQswCQYDVQQGEwJV
3+
UzELMAkGA1UECAwCQ0ExEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzAzMTYxNjMx
4+
NDhaFw0yNTEyMTAxNjMxNDhaMC4xCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJDQTES
5+
MBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
6+
AQEAvLSFc4DafQ2zrU+05P0PGlGjDeab9zXfVZmbEziNQ/8V+RtvmwFw9aENJtRv
7+
IsehZgz505hmTbBtsVPGA3F9MUjBRyMD25ztPxMTA8jQ5xtKhXECErVQ2AmgHHSD
8+
mEWa9S7/Mcu8Ld5dL2lqFsJTG/GjlHcmRMzCHfNKrDQuXeucsNm1mWRlMZ/eK/bj
9+
n+RkeBAKaAN4khVV7UAFeIaISTuD4cUVfv14spB2UV3nmezATXmbR9YCOM91us8Y
10+
rJP3XMJ/vgWVRofDqGCZM9Fs5WlpSbl4/Fei/zDIS/5p6moukpEWNRuV3pN9WI+Q
11+
ciwWbCMwITOmMCXVelRuPpXefwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQBOb+jT
12+
Tseengdf+J7XI4CYtrPoCf9OPZrZHRZcU/d7y04OZpiBin7thBkOsfhESVRABpKN
13+
UhAWtfgnF43Pbxm1l6jg6SbTIwaQ8kXXu+Cx/RaVVCAek5GDStFNSj3ZMsnGrFOO
14+
wtsYoHBLSO11mOu+VcfCbzUqGBaGGLbXySSBx0uwkcSo+Qa5NVXLk54mmWrg84nE
15+
HjdbWTVj+UrSWkJ/9G1bjl1QeKfT9gwusSDKsLiN5QUCVKFjNaARxJ5SDbrxaTmL
16+
1391HGjt1zclDmnypokj56Wkj+HOkUKohC11b1Kgsxl8/Ykyntz89OAj0Vv17xYa
17+
LjQZuCWYlln6S9+S
18+
-----END CERTIFICATE-----

0 commit comments

Comments
 (0)