Skip to content

Commit bf247fb

Browse files
authored
Integration test client
1 parent fa118e3 commit bf247fb

File tree

1 file changed

+303
-0
lines changed

1 file changed

+303
-0
lines changed

tests/testclient.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
import os
2+
import sys
3+
import asyncio
4+
import socket
5+
import subprocess
6+
import time
7+
from contextlib import asynccontextmanager
8+
from typing import Optional, Dict, Any
9+
10+
import aiohttp
11+
import picows
12+
import pytest
13+
import pytest_asyncio
14+
15+
16+
class TransitIntegrationClient:
17+
"""Integration test client for Transit.sh with real Redis and uvicorn processes."""
18+
19+
def __init__(self, base_url: str = "http://localhost:8080", ws_base_url: str = "ws://localhost:8080"):
20+
self.base_url = base_url
21+
self.ws_base_url = ws_base_url
22+
self.http_session: Optional[aiohttp.ClientSession] = None
23+
24+
async def __aenter__(self):
25+
self.http_session = aiohttp.ClientSession()
26+
return self
27+
28+
async def __aexit__(self, exc_type, exc_val, exc_tb):
29+
if self.http_session:
30+
await self.http_session.close()
31+
32+
# HTTP Methods
33+
async def get(self, path: str, **kwargs) -> aiohttp.ClientResponse:
34+
"""Make a GET request."""
35+
url = f"{self.base_url}{path}"
36+
return await self.http_session.get(url, **kwargs)
37+
38+
async def put(self, path: str, **kwargs) -> aiohttp.ClientResponse:
39+
"""Make a PUT request."""
40+
url = f"{self.base_url}{path}"
41+
return await self.http_session.put(url, **kwargs)
42+
43+
async def post(self, path: str, **kwargs) -> aiohttp.ClientResponse:
44+
"""Make a POST request."""
45+
url = f"{self.base_url}{path}"
46+
return await self.http_session.post(url, **kwargs)
47+
48+
# WebSocket Methods
49+
@asynccontextmanager
50+
async def websocket(self, path: str):
51+
"""Create a WebSocket connection using picows."""
52+
url = f"{self.ws_base_url}{path}"
53+
54+
ws_client = await picows.ws_connect(picows.WSTransport, url)
55+
try:
56+
yield WebSocketWrapper(ws_client)
57+
finally:
58+
await ws_client.close()
59+
60+
61+
class WebSocketWrapper:
62+
"""Wrapper around picows WebSocket to provide a simpler API."""
63+
64+
def __init__(self, ws: picows.WSTransport):
65+
self._ws = ws
66+
67+
async def send_text(self, data: str):
68+
"""Send text data."""
69+
await self._ws.send(picows.WSMsgType.TEXT, data.encode())
70+
71+
async def send_bytes(self, data: bytes):
72+
"""Send binary data."""
73+
await self._ws.send(picows.WSMsgType.BINARY, data)
74+
75+
async def send_json(self, data: Dict[str, Any]):
76+
"""Send JSON data."""
77+
import json
78+
await self.send_text(json.dumps(data))
79+
80+
async def receive(self) -> tuple[picows.WSMsgType, bytes]:
81+
"""Receive raw message."""
82+
return await self._ws.recv()
83+
84+
async def receive_text(self) -> str:
85+
"""Receive text message."""
86+
msg_type, data = await self.receive()
87+
if msg_type != picows.WSMsgType.TEXT:
88+
raise ValueError(f"Expected text message, got {msg_type}")
89+
return data.decode()
90+
91+
async def receive_bytes(self) -> bytes:
92+
"""Receive binary message."""
93+
msg_type, data = await self.receive()
94+
if msg_type != picows.WSMsgType.BINARY:
95+
raise ValueError(f"Expected binary message, got {msg_type}")
96+
return data
97+
98+
async def receive_json(self) -> Dict[str, Any]:
99+
"""Receive JSON message."""
100+
import json
101+
text = await self.receive_text()
102+
return json.loads(text)
103+
104+
async def close(self):
105+
"""Close the WebSocket connection."""
106+
await self._ws.close()
107+
108+
109+
def find_free_port() -> int:
110+
"""Find a free port on localhost."""
111+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
112+
s.bind(('', 0))
113+
s.listen(1)
114+
port = s.getsockname()[1]
115+
return port
116+
117+
118+
async def wait_for_port(host: str, port: int, timeout: float = 30.0):
119+
"""Wait for a service to start listening on a port."""
120+
start_time = time.time()
121+
while time.time() - start_time < timeout:
122+
try:
123+
reader, writer = await asyncio.open_connection(host, port)
124+
writer.close()
125+
await writer.wait_closed()
126+
return True
127+
except (ConnectionRefusedError, OSError):
128+
await asyncio.sleep(0.1)
129+
raise TimeoutError(f"Service on {host}:{port} did not start within {timeout} seconds")
130+
131+
132+
async def wait_for_redis(redis_port: int, timeout: float = 30.0):
133+
"""Wait for Redis to be ready by attempting to connect."""
134+
import redis.asyncio as redis
135+
136+
start_time = time.time()
137+
while time.time() - start_time < timeout:
138+
try:
139+
client = redis.Redis(host='localhost', port=redis_port)
140+
await client.ping()
141+
await client.aclose()
142+
return True
143+
except (ConnectionRefusedError, redis.ConnectionError):
144+
await asyncio.sleep(0.1)
145+
raise TimeoutError(f"Redis on port {redis_port} did not start within {timeout} seconds")
146+
147+
148+
class ServerProcessManager:
149+
"""Manages Redis and uvicorn server processes."""
150+
151+
def __init__(self):
152+
self.redis_process: Optional[subprocess.Popen] = None
153+
self.uvicorn_process: Optional[subprocess.Popen] = None
154+
self.redis_port: Optional[int] = None
155+
self.uvicorn_port: Optional[int] = None
156+
157+
async def start_redis(self) -> int:
158+
"""Start Redis server and return its port."""
159+
self.redis_port = find_free_port()
160+
161+
# Start Redis with minimal config
162+
self.redis_process = subprocess.Popen(
163+
[
164+
'redis-server',
165+
'--port', str(self.redis_port),
166+
'--save', '', # Disable persistence
167+
'--appendonly', 'no', # Disable AOF
168+
'--loglevel', 'warning'
169+
],
170+
stdout=subprocess.PIPE,
171+
stderr=subprocess.PIPE
172+
)
173+
174+
# Wait for Redis to be ready
175+
await wait_for_redis(self.redis_port)
176+
return self.redis_port
177+
178+
async def start_uvicorn(self, redis_url: str) -> int:
179+
"""Start uvicorn server and return its port."""
180+
self.uvicorn_port = find_free_port()
181+
182+
# Set environment for the uvicorn process
183+
env = os.environ.copy()
184+
env['REDIS_URL'] = redis_url
185+
186+
# Start uvicorn
187+
self.uvicorn_process = subprocess.Popen(
188+
[
189+
sys.executable, '-m', 'uvicorn',
190+
'app:app',
191+
'--host', '0.0.0.0',
192+
'--port', str(self.uvicorn_port),
193+
'--workers', '1',
194+
'--loop', 'uvloop',
195+
'--ws', 'websockets',
196+
'--log-level', 'warning'
197+
],
198+
env=env,
199+
stdout=subprocess.PIPE,
200+
stderr=subprocess.PIPE,
201+
cwd=os.path.dirname(os.path.abspath(__file__)) # Ensure we're in the project root
202+
)
203+
204+
# Wait for uvicorn to be ready
205+
await wait_for_port('localhost', self.uvicorn_port)
206+
return self.uvicorn_port
207+
208+
def terminate_process(self, process: subprocess.Popen, name: str):
209+
"""Terminate a process gracefully."""
210+
if process and process.poll() is None:
211+
process.terminate()
212+
try:
213+
process.wait(timeout=5)
214+
except subprocess.TimeoutExpired:
215+
print(f"Warning: {name} did not terminate gracefully, killing...")
216+
process.kill()
217+
process.wait()
218+
219+
def cleanup(self):
220+
"""Clean up all processes."""
221+
self.terminate_process(self.redis_process, "Redis")
222+
self.terminate_process(self.uvicorn_process, "uvicorn")
223+
224+
225+
@asynccontextmanager
226+
async def transit_test_servers():
227+
"""Context manager to start and stop Redis and uvicorn servers."""
228+
manager = ServerProcessManager()
229+
230+
try:
231+
# Start Redis
232+
redis_port = await manager.start_redis()
233+
redis_url = f"redis://localhost:{redis_port}"
234+
print(f"Started Redis on port {redis_port}")
235+
236+
# Start uvicorn
237+
uvicorn_port = await manager.start_uvicorn(redis_url)
238+
print(f"Started uvicorn on port {uvicorn_port}")
239+
240+
# Create and yield the test client
241+
base_url = f"http://localhost:{uvicorn_port}"
242+
ws_base_url = f"ws://localhost:{uvicorn_port}"
243+
244+
async with TransitIntegrationClient(base_url, ws_base_url) as client:
245+
yield client
246+
247+
finally:
248+
# Clean up processes
249+
manager.cleanup()
250+
print("Cleaned up test servers")
251+
252+
253+
# Pytest fixtures
254+
@pytest_asyncio.fixture(scope="session")
255+
async def integration_client():
256+
"""Session-scoped fixture that provides the integration test client."""
257+
async with transit_test_servers() as client:
258+
yield client
259+
260+
261+
# Example usage in tests
262+
@pytest.mark.asyncio
263+
async def test_health_endpoint(integration_client: TransitIntegrationClient):
264+
"""Test the health endpoint with real servers."""
265+
response = await integration_client.get("/health")
266+
assert response.status == 200
267+
data = await response.json()
268+
assert data == {"status": "ok"}
269+
270+
271+
@pytest.mark.asyncio
272+
async def test_websocket_upload_integration(integration_client: TransitIntegrationClient):
273+
"""Test WebSocket upload with real servers."""
274+
transfer_id = "test-integration-transfer"
275+
276+
async with integration_client.websocket(f"/send/{transfer_id}") as ws:
277+
# Send file metadata
278+
await ws.send_json({
279+
'file_name': 'test.txt',
280+
'file_size': 11,
281+
'file_type': 'text/plain'
282+
})
283+
284+
# In a real test, you'd have a receiver connect here
285+
# For this example, we'll just verify the connection works
286+
287+
# The actual test would timeout waiting for receiver
288+
# This is just to show the client works
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_http_download_not_found(integration_client: TransitIntegrationClient):
293+
"""Test HTTP download for non-existent transfer."""
294+
response = await integration_client.get("/nonexistent-transfer")
295+
assert response.status == 404
296+
data = await response.json()
297+
assert "not found" in data["detail"].lower()
298+
299+
300+
# Requirements to add to your test dependencies:
301+
# aiohttp>=3.9.0
302+
# picows>=1.0.0
303+
# pytest-asyncio>=0.21.0

0 commit comments

Comments
 (0)