77import socket
88from ipaddress import IPv4Address , IPv6Address
99from ssl import SSLContext
10- from typing import Tuple , Union
10+ from typing import Generic , Tuple , Union
1111
12- from generic_connection_pool .asyncio import BaseConnectionManager
12+ from generic_connection_pool .asyncio import BaseConnectionManager , EndpointT
1313
1414IpAddress = Union [IPv4Address , IPv6Address ]
1515Port = int
1616TcpEndpoint = Tuple [IpAddress , Port ]
1717
1818
19- class TcpSocketConnectionManager (BaseConnectionManager [TcpEndpoint , socket .socket ]):
19+ class SocketAlivenessCheckingMixin (Generic [EndpointT ]):
20+ """
21+ Nonblocking socket aliveness checking mix-in.
22+ """
23+
24+ async def check_aliveness (self , endpoint : EndpointT , conn : socket .socket ) -> bool :
25+ try :
26+ if conn .recv (1 , socket .MSG_PEEK ) == b'' :
27+ return False
28+ except BlockingIOError as exc :
29+ if exc .errno != errno .EAGAIN :
30+ raise
31+ except OSError :
32+ return False
33+
34+ return True
35+
36+
37+ class TcpSocketConnectionManager (
38+ SocketAlivenessCheckingMixin [TcpEndpoint ],
39+ BaseConnectionManager [TcpEndpoint , socket .socket ],
40+ ):
2041 """
2142 TCP socket connection manager.
2243 """
@@ -43,33 +64,40 @@ async def dispose(self, endpoint: TcpEndpoint, conn: socket.socket) -> None:
4364 conn .shutdown (socket .SHUT_RDWR )
4465 conn .close ()
4566
46- async def check_aliveness (self , endpoint : TcpEndpoint , conn : socket .socket ) -> bool :
67+
68+ Hostname = str
69+ TcpStreamEndpoint = Tuple [Hostname , Port ]
70+ Stream = Tuple [asyncio .StreamReader , asyncio .StreamWriter ]
71+
72+
73+ class StreamAlivenessCheckingMixin (Generic [EndpointT ]):
74+ """
75+ Asynchronous stream aliveness checking mix-in.
76+ """
77+
78+ async def check_aliveness (self , endpoint : EndpointT , conn : Stream ) -> bool :
79+ reader , writer = conn
80+
4781 try :
48- if conn .recv (1 , socket .MSG_PEEK ) == b'' :
49- return False
50- except BlockingIOError as exc :
51- if exc .errno != errno .EAGAIN :
52- raise
82+ await reader .read (0 )
5383 except OSError :
5484 return False
5585
56- return True
57-
58-
59- Hostname = str
60- TcpStreamEndpoint = Tuple [Hostname , Port ]
61- TcpStream = Tuple [asyncio .StreamReader , asyncio .StreamWriter ]
86+ return not writer .is_closing () and not reader .at_eof ()
6287
6388
64- class TcpStreamConnectionManager (BaseConnectionManager [TcpStreamEndpoint , TcpStream ]):
89+ class TcpStreamConnectionManager (
90+ StreamAlivenessCheckingMixin [TcpStreamEndpoint ],
91+ BaseConnectionManager [TcpStreamEndpoint , Stream ],
92+ ):
6593 """
6694 TCP stream connection manager.
6795 """
6896
6997 def __init__ (self , ssl : Union [None , bool , SSLContext ] = None ):
7098 self ._ssl = ssl
7199
72- async def create (self , endpoint : TcpStreamEndpoint ) -> TcpStream :
100+ async def create (self , endpoint : TcpStreamEndpoint ) -> Stream :
73101 hostname , port = endpoint
74102 server_hostname = hostname if self ._ssl is not None else None
75103
@@ -82,20 +110,10 @@ async def create(self, endpoint: TcpStreamEndpoint) -> TcpStream:
82110
83111 return reader , writer
84112
85- async def dispose (self , endpoint : TcpStreamEndpoint , conn : TcpStream ) -> None :
113+ async def dispose (self , endpoint : TcpStreamEndpoint , conn : Stream ) -> None :
86114 reader , writer = conn
87115 if writer .can_write_eof ():
88116 writer .write_eof ()
89117
90118 writer .close ()
91119 await writer .wait_closed ()
92-
93- async def check_aliveness (self , endpoint : TcpStreamEndpoint , conn : TcpStream ) -> bool :
94- reader , writer = conn
95-
96- try :
97- await reader .read (0 )
98- except OSError :
99- return False
100-
101- return not writer .is_closing () and not reader .at_eof ()
0 commit comments