11import asyncio
2+ import socket
23import ssl
34from ipaddress import IPv4Address
45from pathlib import Path
1314class TCPServer :
1415 @staticmethod
1516 async def echo_handler (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
16- data = await reader .read (1500 )
17- writer .write (data )
18- await writer .drain ()
17+ while data := await reader .read (1024 ):
18+ writer .write (data )
19+ await writer .drain ()
20+
1921 writer .close ()
2022 await writer .wait_closed ()
2123
@@ -31,6 +33,7 @@ async def start(self) -> None:
3133 host = self ._hostname ,
3234 port = self ._port ,
3335 ssl = self ._ssl_ctx ,
36+ family = socket .AF_INET ,
3437 reuse_port = True ,
3538 )
3639 self ._server_task = asyncio .create_task (server .serve_forever ())
@@ -68,44 +71,73 @@ async def ssl_server(
6871 await server .stop ()
6972
7073
74+ @pytest .mark .timeout (5.0 )
7175async def test_tcp_socket_manager (tcp_server : Tuple [IPv4Address , int ]):
7276 loop = asyncio .get_running_loop ()
7377 addr , port = tcp_server
7478
7579 pool = ConnectionPool (TcpSocketConnectionManager ())
76- async with pool .connection ((addr , port )) as sock :
77- request = b'test'
78- await loop .sock_sendall (sock , request )
79- response = await loop .sock_recv (sock , len (request ))
80- assert response == request
80+
81+ attempts = 3
82+ request = b'test'
83+ for _ in range (attempts ):
84+ async with pool .connection ((addr , port )) as sock1 :
85+ await loop .sock_sendall (sock1 , request )
86+ response = await loop .sock_recv (sock1 , len (request ))
87+ assert response == request
88+
89+ async with pool .connection ((addr , port )) as sock2 :
90+ await loop .sock_sendall (sock2 , request )
91+ response = await loop .sock_recv (sock2 , len (request ))
92+ assert response == request
8193
8294 await pool .close ()
8395
8496
97+ @pytest .mark .timeout (5.0 )
8598async def test_tcp_stream_manager (resource_dir : Path , tcp_server : Tuple [IPv4Address , int ]):
8699 addr , port = tcp_server
87100
88101 pool = ConnectionPool (TcpStreamConnectionManager (ssl = None ))
89- async with pool .connection ((str (addr ), port )) as (reader , writer ):
90- request = b'test'
91- writer .write (request )
92- await writer .drain ()
93- response = await reader .read ()
94- assert response == request
102+
103+ attempts = 3
104+ request = b'test'
105+ for _ in range (attempts ):
106+ async with pool .connection ((str (addr ), port )) as (reader1 , writer1 ):
107+ writer1 .write (request )
108+ await writer1 .drain ()
109+ response = await reader1 .read (len (request ))
110+ assert response == request
111+
112+ async with pool .connection ((str (addr ), port )) as (reader2 , writer2 ):
113+ writer2 .write (request )
114+ await writer2 .drain ()
115+ response = await reader2 .read (len (request ))
116+ assert response == request
95117
96118 await pool .close ()
97119
98120
121+ @pytest .mark .timeout (5.0 )
99122async def test_ssl_stream_manager (resource_dir : Path , ssl_server : Tuple [str , int ]):
100123 hostname , port = ssl_server
101124 ssl_context = ssl .create_default_context (cafile = resource_dir / 'ssl.cert' )
102125
103126 pool = ConnectionPool (TcpStreamConnectionManager (ssl_context ))
104- async with pool .connection ((hostname , port )) as (reader , writer ):
105- request = b'test'
106- writer .write (request )
107- await writer .drain ()
108- response = await reader .read ()
109- assert response == request
127+
128+ attempts = 3
129+ request = b'test'
130+ for _ in range (attempts ):
131+ async with pool .connection ((hostname , port )) as (reader1 , writer1 ):
132+ writer1 .write (request )
133+ await writer1 .drain ()
134+ response = await reader1 .read (len (request ))
135+ assert response == request
136+
137+ async with pool .connection ((hostname , port )) as (reader2 , writer2 ):
138+ writer2 .write (request )
139+ await writer2 .drain ()
140+ response = await reader2 .read (len (request ))
141+ assert response == request
110142
111143 await pool .close ()
0 commit comments