Skip to content

Commit c093fa7

Browse files
authored
Merge pull request #13 from dapper91/dev
- socket connection manager aliveness check added - shared lock bug fixed
2 parents 4d87c75 + 2015e2f commit c093fa7

File tree

14 files changed

+345
-187
lines changed

14 files changed

+345
-187
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
Changelog
22
=========
33

4+
0.4.0 (2023-08-15)
5+
------------------
6+
7+
- socket connection manager aliveness check added
8+
- shared lock bug fixed
9+
10+
411
0.3.0 (2023-08-10)
512
------------------
613

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
"""
2+
Asyncio connection pool implementation.
3+
"""
4+
15
from .locks import SharedLock
2-
from .pool import BaseConnectionManager, ConnectionPool
6+
from .pool import BaseConnectionManager, ConnectionPool, ConnectionT, EndpointT
37

48
__all__ = [
59
'BaseConnectionManager',
610
'ConnectionPool',
11+
'ConnectionT',
12+
'EndpointT',
713
]

generic_connection_pool/asyncio/locks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,6 @@ async def _acquire_exclusive(self, timeout: Optional[float] = None) -> bool:
151151
return True
152152

153153
def _release_exclusive(self) -> None:
154-
self._lock.release()
155154
self._exclusive = False
156155
self._exclusive_owner = None
156+
self._lock.release()

generic_connection_pool/contrib/socket.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,30 @@
33
"""
44

55
import contextlib
6+
import errno
67
import socket
78
from ipaddress import IPv4Address, IPv6Address
89
from ssl import SSLContext, SSLSocket
910
from typing import Generator, Optional, Tuple, Union
1011

11-
from generic_connection_pool.threading import BaseConnectionManager
12+
from generic_connection_pool.threading import BaseConnectionManager, EndpointT
1213

1314
IpAddress = Union[IPv4Address, IPv6Address]
1415
Hostname = str
1516
Port = int
1617
TcpEndpoint = Tuple[IpAddress, Port]
1718

1819

20+
@contextlib.contextmanager
21+
def socket_nonblocking(sock: socket.socket) -> Generator[None, None, None]:
22+
orig_timeout = sock.gettimeout()
23+
sock.settimeout(0)
24+
try:
25+
yield
26+
finally:
27+
sock.settimeout(orig_timeout)
28+
29+
1930
@contextlib.contextmanager
2031
def socket_timeout(sock: socket.socket, timeout: Optional[float]) -> Generator[None, None, None]:
2132
if timeout is None:
@@ -65,6 +76,19 @@ def dispose(self, endpoint: TcpEndpoint, conn: socket.socket, timeout: Optional[
6576

6677
conn.close()
6778

79+
def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool:
80+
try:
81+
with socket_nonblocking(conn):
82+
if conn.recv(1, socket.MSG_PEEK) == b'':
83+
return False
84+
except BlockingIOError as exc:
85+
if exc.errno != errno.EAGAIN:
86+
raise
87+
except OSError:
88+
return False
89+
90+
return True
91+
6892

6993
SslEndpoint = Tuple[Hostname, Port]
7094

@@ -94,3 +118,17 @@ def dispose(self, endpoint: SslEndpoint, conn: SSLSocket, timeout: Optional[floa
94118
pass
95119

96120
conn.close()
121+
122+
def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool:
123+
try:
124+
with socket_nonblocking(conn):
125+
# peek into the plain socket since ssl socket doesn't support flags
126+
if socket.socket.recv(conn, 1, socket.MSG_PEEK) == b'':
127+
return False
128+
except BlockingIOError as exc:
129+
if exc.errno != errno.EAGAIN:
130+
raise
131+
except OSError:
132+
return False
133+
134+
return True

generic_connection_pool/contrib/socket_async.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import asyncio
6+
import errno
67
import socket
78
from ipaddress import IPv4Address, IPv6Address
89
from ssl import SSLContext
@@ -42,6 +43,18 @@ async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None:
4243
conn.shutdown(socket.SHUT_RDWR)
4344
conn.close()
4445

46+
async def check_aliveness(self, endpoint: TcpEndpoint, conn: socket.socket) -> bool:
47+
try:
48+
if conn.recv(1, socket.MSG_PEEK) == b'':
49+
return False
50+
except BlockingIOError as exc:
51+
if exc.errno != errno.EAGAIN:
52+
raise
53+
except OSError:
54+
return False
55+
56+
return True
57+
4558

4659
Hostname = str
4760
TcpStreamEndpoint = Tuple[Hostname, Port]
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
from .pool import BaseConnectionManager, ConnectionPool
1+
"""
2+
Threading connection pool implementation.
3+
"""
4+
5+
from .pool import BaseConnectionManager, ConnectionPool, ConnectionT, EndpointT
26

37
__all__ = [
48
'BaseConnectionManager',
59
'ConnectionPool',
10+
'ConnectionT',
11+
'EndpointT',
612
]

generic_connection_pool/threading/locks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ def _acquire_exclusive(self, blocking: bool = True, timeout: Optional[float] = N
131131
return True
132132

133133
def _release_exclusive(self) -> None:
134-
self._lock.release()
135134
self._exclusive = False
136135
self._exclusive_owner = None
136+
self._lock.release()
137137

138138
def _is_owned(self) -> bool: # to be compatible with threading.Condition
139139
return self._exclusive_owner == threading.get_ident()

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
[tool.poetry]
22
name = "generic-connection-pool"
3-
version = "0.3.0"
3+
version = "0.4.0"
44
description = "generic connection pool"
55
authors = ["Dmitry Pershin <[email protected]>"]
66
license = "Unlicense"
77
readme = "README.rst"
88
homepage = "https://github.com/dapper91/generic-connection-pool"
99
repository = "https://github.com/dapper91/generic-connection-pool"
10-
documentation = "https://github.com/dapper91/generic-connection-pool"
10+
documentation = "https://generic-connection-pool.readthedocs.io"
1111
keywords = ['pool', 'connection-pool', 'asyncio', 'socket', 'tcp']
1212
classifiers = [
1313
"Development Status :: 5 - Production/Stable",
@@ -44,6 +44,7 @@ pytest-asyncio = "^0.21.1"
4444
pytest-mock = "^3.11.1"
4545
pytest-cov = "^4.1.0"
4646
types-psycopg2 = "^2.9.21"
47+
pytest-timeout = "^2.1.0"
4748

4849
[build-system]
4950
requires = ["poetry-core>=1.0.0"]

tests/contrib/test_async_socket_manager.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import socket
23
import ssl
34
from ipaddress import IPv4Address
45
from pathlib import Path
@@ -13,9 +14,10 @@
1314
class TCPServer:
1415
@staticmethod
1516
async def echo_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
16-
data = await reader.read(1500)
17-
writer.write(data)
18-
await writer.drain()
17+
while data := await reader.read(1024):
18+
writer.write(data)
19+
await writer.drain()
20+
1921
writer.close()
2022
await writer.wait_closed()
2123

@@ -31,6 +33,7 @@ async def start(self) -> None:
3133
host=self._hostname,
3234
port=self._port,
3335
ssl=self._ssl_ctx,
36+
family=socket.AF_INET,
3437
reuse_port=True,
3538
)
3639
self._server_task = asyncio.create_task(server.serve_forever())
@@ -68,44 +71,73 @@ async def ssl_server(
6871
await server.stop()
6972

7073

74+
@pytest.mark.timeout(5.0)
7175
async def test_tcp_socket_manager(tcp_server: Tuple[IPv4Address, int]):
7276
loop = asyncio.get_running_loop()
7377
addr, port = tcp_server
7478

7579
pool = ConnectionPool(TcpSocketConnectionManager())
76-
async with pool.connection((addr, port)) as sock:
77-
request = b'test'
78-
await loop.sock_sendall(sock, request)
79-
response = await loop.sock_recv(sock, len(request))
80-
assert response == request
80+
81+
attempts = 3
82+
request = b'test'
83+
for _ in range(attempts):
84+
async with pool.connection((addr, port)) as sock1:
85+
await loop.sock_sendall(sock1, request)
86+
response = await loop.sock_recv(sock1, len(request))
87+
assert response == request
88+
89+
async with pool.connection((addr, port)) as sock2:
90+
await loop.sock_sendall(sock2, request)
91+
response = await loop.sock_recv(sock2, len(request))
92+
assert response == request
8193

8294
await pool.close()
8395

8496

97+
@pytest.mark.timeout(5.0)
8598
async def test_tcp_stream_manager(resource_dir: Path, tcp_server: Tuple[IPv4Address, int]):
8699
addr, port = tcp_server
87100

88101
pool = ConnectionPool(TcpStreamConnectionManager(ssl=None))
89-
async with pool.connection((str(addr), port)) as (reader, writer):
90-
request = b'test'
91-
writer.write(request)
92-
await writer.drain()
93-
response = await reader.read()
94-
assert response == request
102+
103+
attempts = 3
104+
request = b'test'
105+
for _ in range(attempts):
106+
async with pool.connection((str(addr), port)) as (reader1, writer1):
107+
writer1.write(request)
108+
await writer1.drain()
109+
response = await reader1.read(len(request))
110+
assert response == request
111+
112+
async with pool.connection((str(addr), port)) as (reader2, writer2):
113+
writer2.write(request)
114+
await writer2.drain()
115+
response = await reader2.read(len(request))
116+
assert response == request
95117

96118
await pool.close()
97119

98120

121+
@pytest.mark.timeout(5.0)
99122
async def test_ssl_stream_manager(resource_dir: Path, ssl_server: Tuple[str, int]):
100123
hostname, port = ssl_server
101124
ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert')
102125

103126
pool = ConnectionPool(TcpStreamConnectionManager(ssl_context))
104-
async with pool.connection((hostname, port)) as (reader, writer):
105-
request = b'test'
106-
writer.write(request)
107-
await writer.drain()
108-
response = await reader.read()
109-
assert response == request
127+
128+
attempts = 3
129+
request = b'test'
130+
for _ in range(attempts):
131+
async with pool.connection((hostname, port)) as (reader1, writer1):
132+
writer1.write(request)
133+
await writer1.drain()
134+
response = await reader1.read(len(request))
135+
assert response == request
136+
137+
async with pool.connection((hostname, port)) as (reader2, writer2):
138+
writer2.write(request)
139+
await writer2.drain()
140+
response = await reader2.read(len(request))
141+
assert response == request
110142

111143
await pool.close()

0 commit comments

Comments
 (0)