Skip to content

Commit 970402d

Browse files
authored
Merge pull request #15 from dapper91/dev
- aliveness checking extracted to mixins.
2 parents c093fa7 + 2cd94bf commit 970402d

File tree

4 files changed

+90
-48
lines changed

4 files changed

+90
-48
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changelog
22
=========
33

4+
0.4.1 (2023-08-16)
5+
------------------
6+
7+
- aliveness checking extracted to mixins.
8+
9+
410
0.4.0 (2023-08-15)
511
------------------
612

generic_connection_pool/contrib/socket.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import socket
88
from ipaddress import IPv4Address, IPv6Address
99
from ssl import SSLContext, SSLSocket
10-
from typing import Generator, Optional, Tuple, Union
10+
from typing import Generator, Generic, Optional, Tuple, Union
1111

1212
from generic_connection_pool.threading import BaseConnectionManager, EndpointT
1313

@@ -46,7 +46,29 @@ def socket_timeout(sock: socket.socket, timeout: Optional[float]) -> Generator[N
4646
sock.settimeout(orig_timeout)
4747

4848

49-
class TcpSocketConnectionManager(BaseConnectionManager[TcpEndpoint, socket.socket]):
49+
class SocketAlivenessCheckingMixin(Generic[EndpointT]):
50+
"""
51+
Socket aliveness checking mix-in.
52+
"""
53+
54+
def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool:
55+
try:
56+
with socket_nonblocking(conn):
57+
if conn.recv(1, socket.MSG_PEEK) == b'':
58+
return False
59+
except BlockingIOError as exc:
60+
if exc.errno != errno.EAGAIN:
61+
raise
62+
except OSError:
63+
return False
64+
65+
return True
66+
67+
68+
class TcpSocketConnectionManager(
69+
SocketAlivenessCheckingMixin[TcpEndpoint],
70+
BaseConnectionManager[TcpEndpoint, socket.socket],
71+
):
5072
"""
5173
TCP socket connection manager.
5274
"""
@@ -76,10 +98,17 @@ def dispose(self, endpoint: TcpEndpoint, conn: socket.socket, timeout: Optional[
7698

7799
conn.close()
78100

79-
def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool:
101+
102+
class SslSocketAlivenessCheckingMixin(Generic[EndpointT]):
103+
"""
104+
SSL socket aliveness checking mix-in.
105+
"""
106+
107+
def check_aliveness(self, endpoint: EndpointT, conn: SSLSocket, timeout: Optional[float] = None) -> bool:
80108
try:
81109
with socket_nonblocking(conn):
82-
if conn.recv(1, socket.MSG_PEEK) == b'':
110+
# peek into the plain socket since ssl socket doesn't support flags
111+
if socket.socket.recv(conn, 1, socket.MSG_PEEK) == b'':
83112
return False
84113
except BlockingIOError as exc:
85114
if exc.errno != errno.EAGAIN:
@@ -93,7 +122,10 @@ def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Opt
93122
SslEndpoint = Tuple[Hostname, Port]
94123

95124

96-
class SslSocketConnectionManager(BaseConnectionManager[SslEndpoint, SSLSocket]):
125+
class SslSocketConnectionManager(
126+
SslSocketAlivenessCheckingMixin[SslEndpoint],
127+
BaseConnectionManager[SslEndpoint, SSLSocket],
128+
):
97129
"""
98130
SSL socket connection manager.
99131
"""
@@ -118,17 +150,3 @@ def dispose(self, endpoint: SslEndpoint, conn: SSLSocket, timeout: Optional[floa
118150
pass
119151

120152
conn.close()
121-
122-
def check_aliveness(self, endpoint: EndpointT, conn: socket.socket, timeout: Optional[float] = None) -> bool:
123-
try:
124-
with socket_nonblocking(conn):
125-
# peek into the plain socket since ssl socket doesn't support flags
126-
if socket.socket.recv(conn, 1, socket.MSG_PEEK) == b'':
127-
return False
128-
except BlockingIOError as exc:
129-
if exc.errno != errno.EAGAIN:
130-
raise
131-
except OSError:
132-
return False
133-
134-
return True

generic_connection_pool/contrib/socket_async.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,37 @@
77
import socket
88
from ipaddress import IPv4Address, IPv6Address
99
from ssl import SSLContext
10-
from typing import Tuple, Union
10+
from typing import Generic, Tuple, Union
1111

12-
from generic_connection_pool.asyncio import BaseConnectionManager
12+
from generic_connection_pool.asyncio import BaseConnectionManager, EndpointT
1313

1414
IpAddress = Union[IPv4Address, IPv6Address]
1515
Port = int
1616
TcpEndpoint = Tuple[IpAddress, Port]
1717

1818

19-
class TcpSocketConnectionManager(BaseConnectionManager[TcpEndpoint, socket.socket]):
19+
class SocketAlivenessCheckingMixin(Generic[EndpointT]):
20+
"""
21+
Nonblocking socket aliveness checking mix-in.
22+
"""
23+
24+
async def check_aliveness(self, endpoint: EndpointT, conn: socket.socket) -> bool:
25+
try:
26+
if conn.recv(1, socket.MSG_PEEK) == b'':
27+
return False
28+
except BlockingIOError as exc:
29+
if exc.errno != errno.EAGAIN:
30+
raise
31+
except OSError:
32+
return False
33+
34+
return True
35+
36+
37+
class TcpSocketConnectionManager(
38+
SocketAlivenessCheckingMixin[TcpEndpoint],
39+
BaseConnectionManager[TcpEndpoint, socket.socket],
40+
):
2041
"""
2142
TCP socket connection manager.
2243
"""
@@ -43,33 +64,40 @@ async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None:
4364
conn.shutdown(socket.SHUT_RDWR)
4465
conn.close()
4566

46-
async def check_aliveness(self, endpoint: TcpEndpoint, conn: socket.socket) -> bool:
67+
68+
Hostname = str
69+
TcpStreamEndpoint = Tuple[Hostname, Port]
70+
Stream = Tuple[asyncio.StreamReader, asyncio.StreamWriter]
71+
72+
73+
class StreamAlivenessCheckingMixin(Generic[EndpointT]):
74+
"""
75+
Asynchronous stream aliveness checking mix-in.
76+
"""
77+
78+
async def check_aliveness(self, endpoint: EndpointT, conn: Stream) -> bool:
79+
reader, writer = conn
80+
4781
try:
48-
if conn.recv(1, socket.MSG_PEEK) == b'':
49-
return False
50-
except BlockingIOError as exc:
51-
if exc.errno != errno.EAGAIN:
52-
raise
82+
await reader.read(0)
5383
except OSError:
5484
return False
5585

56-
return True
57-
58-
59-
Hostname = str
60-
TcpStreamEndpoint = Tuple[Hostname, Port]
61-
TcpStream = Tuple[asyncio.StreamReader, asyncio.StreamWriter]
86+
return not writer.is_closing() and not reader.at_eof()
6287

6388

64-
class TcpStreamConnectionManager(BaseConnectionManager[TcpStreamEndpoint, TcpStream]):
89+
class TcpStreamConnectionManager(
90+
StreamAlivenessCheckingMixin[TcpStreamEndpoint],
91+
BaseConnectionManager[TcpStreamEndpoint, Stream],
92+
):
6593
"""
6694
TCP stream connection manager.
6795
"""
6896

6997
def __init__(self, ssl: Union[None, bool, SSLContext] = None):
7098
self._ssl = ssl
7199

72-
async def create(self, endpoint: TcpStreamEndpoint) -> TcpStream:
100+
async def create(self, endpoint: TcpStreamEndpoint) -> Stream:
73101
hostname, port = endpoint
74102
server_hostname = hostname if self._ssl is not None else None
75103

@@ -82,20 +110,10 @@ async def create(self, endpoint: TcpStreamEndpoint) -> TcpStream:
82110

83111
return reader, writer
84112

85-
async def dispose(self, endpoint: TcpStreamEndpoint, conn: TcpStream) -> None:
113+
async def dispose(self, endpoint: TcpStreamEndpoint, conn: Stream) -> None:
86114
reader, writer = conn
87115
if writer.can_write_eof():
88116
writer.write_eof()
89117

90118
writer.close()
91119
await writer.wait_closed()
92-
93-
async def check_aliveness(self, endpoint: TcpStreamEndpoint, conn: TcpStream) -> bool:
94-
reader, writer = conn
95-
96-
try:
97-
await reader.read(0)
98-
except OSError:
99-
return False
100-
101-
return not writer.is_closing() and not reader.at_eof()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "generic-connection-pool"
3-
version = "0.4.0"
3+
version = "0.4.1"
44
description = "generic connection pool"
55
authors = ["Dmitry Pershin <[email protected]>"]
66
license = "Unlicense"

0 commit comments

Comments
 (0)