22import ssl
33from ipaddress import IPv4Address
44from pathlib import Path
5- from typing import Tuple
5+ from typing import AsyncGenerator , Generator , Optional , Tuple
66
77import pytest
8- import pytest_asyncio .plugin
98
109from generic_connection_pool .asyncio import ConnectionPool
1110from generic_connection_pool .contrib .socket_async import TcpSocketConnectionManager , TcpStreamConnectionManager
1211
1312
14- @pytest .fixture
15- async def tcp_server (request : pytest_asyncio .plugin .SubRequest , resource_dir : Path ):
16- params = getattr (request , 'param' , {})
17- user_ssl = params .get ('use_ssl' , False )
18-
19- if user_ssl :
20- context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
21- context .load_cert_chain (resource_dir / 'ssl.cert' , resource_dir / 'ssl.key' )
22- else :
23- context = None
24-
25- hostname , addr , port = 'localhost' , IPv4Address ('127.0.0.1' ), 10000
26-
27- async def echo (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
28- data = await reader .read (1024 )
13+ class TCPServer :
14+ @staticmethod
15+ async def echo_handler (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
16+ data = await reader .read (1500 )
2917 writer .write (data )
3018 await writer .drain ()
3119 writer .close ()
3220 await writer .wait_closed ()
3321
34- async with (server := await asyncio .start_server (echo , host = hostname , port = port , ssl = context , reuse_port = True )):
35- server_task = asyncio .create_task (server .serve_forever ())
36- yield hostname , addr , port
37- server_task .cancel ()
38- try :
39- await server_task
40- except asyncio .CancelledError :
41- pass
22+ def __init__ (self , hostname : str , port : int , ssl_ctx : Optional [ssl .SSLContext ] = None ):
23+ self ._hostname = hostname
24+ self ._port = port
25+ self ._ssl_ctx = ssl_ctx
26+ self ._server_task : Optional [asyncio .Task [None ]] = None
27+
28+ async def start (self ) -> None :
29+ server = await asyncio .start_server (
30+ self .echo_handler ,
31+ host = self ._hostname ,
32+ port = self ._port ,
33+ ssl = self ._ssl_ctx ,
34+ reuse_port = True ,
35+ )
36+ self ._server_task = asyncio .create_task (server .serve_forever ())
37+
38+ async def stop (self ) -> None :
39+ if (server_task := self ._server_task ) is not None :
40+ server_task .cancel ()
41+ try :
42+ await server_task
43+ except asyncio .CancelledError :
44+ pass
45+
46+
47+ @pytest .fixture
48+ async def tcp_server (port_gen : Generator [int , None , None ]) -> AsyncGenerator [Tuple [IPv4Address , int ], None ]:
49+ addr , port = IPv4Address ('127.0.0.1' ), next (port_gen )
50+ server = TCPServer (str (addr ), port )
51+ await server .start ()
52+ yield addr , port
53+ await server .stop ()
54+
55+
56+ @pytest .fixture
57+ async def ssl_server (resource_dir : Path , port_gen : Generator [int , None , None ]) -> AsyncGenerator [Tuple [str , int ], None ]:
58+ hostname , port = 'localhost' , next (port_gen )
59+ ssl_ctx = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
60+ ssl_ctx .load_cert_chain (resource_dir / 'ssl.cert' , resource_dir / 'ssl.key' )
61+
62+ server = TCPServer (hostname , port , ssl_ctx = ssl_ctx )
63+ await server .start ()
64+ yield hostname , port
65+ await server .stop ()
4266
4367
44- async def test_tcp_socket_manager (tcp_server : Tuple [str , IPv4Address , int ]):
68+ async def test_tcp_socket_manager (tcp_server : Tuple [IPv4Address , int ]):
4569 loop = asyncio .get_running_loop ()
46- hostname , addr , port = tcp_server
70+ addr , port = tcp_server
4771
4872 pool = ConnectionPool (TcpSocketConnectionManager ())
4973 async with pool .connection ((addr , port )) as sock :
@@ -55,20 +79,23 @@ async def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]):
5579 await pool .close ()
5680
5781
58- @pytest .mark .parametrize (
59- 'use_ssl, tcp_server' , [
60- (True , {'use_ssl' : True }),
61- (False , {'use_ssl' : False }),
62- ],
63- indirect = ['tcp_server' ],
64- )
65- async def test_ssl_stream_manager (resource_dir : Path , use_ssl : bool , tcp_server : Tuple [str , IPv4Address , int ]):
66- hostname , addr , port = tcp_server
67-
68- if use_ssl :
69- ssl_context = ssl .create_default_context (cafile = resource_dir / 'ssl.cert' )
70- else :
71- ssl_context = None
82+ async def test_tcp_stream_manager (resource_dir : Path , tcp_server : Tuple [IPv4Address , int ]):
83+ addr , port = tcp_server
84+
85+ pool = ConnectionPool (TcpStreamConnectionManager (ssl = None ))
86+ async with pool .connection ((str (addr ), port )) as (reader , writer ):
87+ request = b'test'
88+ writer .write (request )
89+ await writer .drain ()
90+ response = await reader .read ()
91+ assert response == request
92+
93+ await pool .close ()
94+
95+
96+ async def test_ssl_stream_manager (resource_dir : Path , ssl_server : Tuple [str , int ]):
97+ hostname , port = ssl_server
98+ ssl_context = ssl .create_default_context (cafile = resource_dir / 'ssl.cert' )
7299
73100 pool = ConnectionPool (TcpStreamConnectionManager (ssl_context ))
74101 async with pool .connection ((hostname , port )) as (reader , writer ):
0 commit comments