|
| 1 | +import asyncio |
| 2 | +import ssl |
| 3 | +from ipaddress import IPv4Address |
| 4 | +from pathlib import Path |
| 5 | +from typing import Tuple |
| 6 | + |
| 7 | +import pytest |
| 8 | +import pytest_asyncio.plugin |
| 9 | + |
| 10 | +from generic_connection_pool.asyncio import ConnectionPool |
| 11 | +from generic_connection_pool.contrib.socket_async import TcpSocketConnectionManager, TcpStreamConnectionManager |
| 12 | + |
| 13 | + |
| 14 | +@pytest.fixture |
| 15 | +async def tcp_server(request: pytest_asyncio.plugin.SubRequest, resource_dir: Path): |
| 16 | + params = getattr(request, 'param', {}) |
| 17 | + user_ssl = params.get('use_ssl', False) |
| 18 | + |
| 19 | + if user_ssl: |
| 20 | + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
| 21 | + context.load_cert_chain(resource_dir / 'ssl.cert', resource_dir / 'ssl.key') |
| 22 | + else: |
| 23 | + context = None |
| 24 | + |
| 25 | + hostname, addr, port = 'localhost', IPv4Address('127.0.0.1'), 10000 |
| 26 | + |
| 27 | + async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): |
| 28 | + data = await reader.read(1024) |
| 29 | + writer.write(data) |
| 30 | + await writer.drain() |
| 31 | + writer.close() |
| 32 | + await writer.wait_closed() |
| 33 | + |
| 34 | + async with (server := await asyncio.start_server(echo, host=hostname, port=port, ssl=context, reuse_port=True)): |
| 35 | + server_task = asyncio.create_task(server.serve_forever()) |
| 36 | + yield hostname, addr, port |
| 37 | + server_task.cancel() |
| 38 | + try: |
| 39 | + await server_task |
| 40 | + except asyncio.CancelledError: |
| 41 | + pass |
| 42 | + |
| 43 | + |
| 44 | +async def test_tcp_socket_manager(tcp_server: Tuple[str, IPv4Address, int]): |
| 45 | + loop = asyncio.get_running_loop() |
| 46 | + hostname, addr, port = tcp_server |
| 47 | + |
| 48 | + pool = ConnectionPool(TcpSocketConnectionManager()) |
| 49 | + async with pool.connection((addr, port)) as sock: |
| 50 | + request = b'test' |
| 51 | + await loop.sock_sendall(sock, request) |
| 52 | + response = await loop.sock_recv(sock, len(request)) |
| 53 | + assert response == request |
| 54 | + |
| 55 | + await pool.close() |
| 56 | + |
| 57 | + |
| 58 | +@pytest.mark.parametrize( |
| 59 | + 'use_ssl, tcp_server', [ |
| 60 | + (True, {'use_ssl': True}), |
| 61 | + (False, {'use_ssl': False}), |
| 62 | + ], |
| 63 | + indirect=['tcp_server'], |
| 64 | +) |
| 65 | +async def test_ssl_stream_manager(resource_dir: Path, use_ssl: bool, tcp_server: Tuple[str, IPv4Address, int]): |
| 66 | + hostname, addr, port = tcp_server |
| 67 | + |
| 68 | + if use_ssl: |
| 69 | + ssl_context = ssl.create_default_context(cafile=resource_dir / 'ssl.cert') |
| 70 | + else: |
| 71 | + ssl_context = None |
| 72 | + |
| 73 | + pool = ConnectionPool(TcpStreamConnectionManager(ssl_context)) |
| 74 | + async with pool.connection((hostname, port)) as (reader, writer): |
| 75 | + request = b'test' |
| 76 | + writer.write(request) |
| 77 | + await writer.drain() |
| 78 | + response = await reader.read() |
| 79 | + assert response == request |
| 80 | + |
| 81 | + await pool.close() |
0 commit comments