Skip to content

Commit cb0e7c6

Browse files
committed
New integration tests in separate process
1 parent 28619f6 commit cb0e7c6

File tree

5 files changed

+250
-110
lines changed

5 files changed

+250
-110
lines changed

tests/conftest.py

Lines changed: 92 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,113 @@
1+
import os
2+
import sys
3+
import time
14
import httpx
25
import pytest
3-
import fakeredis
6+
import socket
7+
import subprocess
8+
import redis as redis_client
49
from typing import AsyncIterator
5-
from redis import asyncio as redis
6-
from unittest.mock import AsyncMock, patch
7-
from starlette.testclient import TestClient
810

9-
from app import app
10-
from lib.store import Store
11+
from tests.ws_client import WebSocketTestClient
12+
from lib.logging import get_logger
13+
log = get_logger('setup-tests')
1114

1215

13-
@pytest.fixture
16+
def find_free_port():
17+
"""Find a free port on localhost."""
18+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
19+
s.bind(('127.0.0.1', 0))
20+
s.listen(1)
21+
port = s.getsockname()[1]
22+
return port
23+
24+
25+
@pytest.fixture(scope="session")
1426
def anyio_backend():
1527
return 'asyncio'
1628

17-
@pytest.fixture
18-
async def redis_client() -> AsyncIterator[redis.Redis]:
19-
async with fakeredis.FakeAsyncRedis() as client:
20-
yield client
2129

22-
@pytest.fixture
23-
async def test_client(redis_client: redis.Redis) -> AsyncIterator[httpx.AsyncClient]:
24-
def get_redis_override(*args, **kwargs) -> redis.Redis:
25-
return redis_client
30+
@pytest.fixture(scope="session")
31+
def live_server():
32+
"""Start uvicorn server in a subprocess."""
33+
port = find_free_port()
34+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
35+
processes = {}
36+
37+
try:
38+
print()
39+
log.debug(f"Starting test server...")
40+
redis_proc = subprocess.Popen(
41+
['redis-server', '--port', '6379', '--save', '', '--appendonly', 'no'],
42+
cwd=project_root,
43+
stdout=subprocess.DEVNULL
44+
)
45+
processes['redis'] = redis_proc
46+
47+
try:
48+
time.sleep(2)
49+
redis_client.from_url('redis://127.0.0.1:6379').ping()
50+
except redis_client.ConnectionError:
51+
log.error("Failed to connect to Redis server. Ensure Redis is running.")
52+
raise
53+
54+
uvicorn_proc = subprocess.Popen(
55+
['uvicorn', 'app:app', '--host', '127.0.0.1', '--port', str(port)],
56+
cwd=project_root
57+
)
58+
processes['uvicorn'] = uvicorn_proc
2659

27-
# Make sure the app has the Redis state set up
28-
app.state.redis = redis_client
60+
base_url = f'127.0.0.1:{port}'
61+
max_retries = 5
62+
for i in range(max_retries):
63+
try:
64+
response = httpx.get(f'http://{base_url}/health', timeout=5)
65+
if response.status_code == 200:
66+
break
67+
except Exception as e:
68+
if i == max_retries - 1:
69+
uvicorn_proc.terminate()
70+
raise RuntimeError(f"Server failed to start after {max_retries} attempts") from None
71+
72+
time.sleep(2.0)
73+
74+
log.debug(f"Server started at {base_url}")
75+
print()
76+
yield base_url
77+
78+
print()
79+
for name, process in sorted(processes.items(), key=lambda x: -ord(x[0][0])):
80+
if process.poll() is None:
81+
log.debug(f"Terminating {name} process")
82+
process.terminate()
83+
try:
84+
process.wait(timeout=5)
85+
except subprocess.TimeoutExpired: pass
86+
87+
finally:
88+
for name, process in processes.items():
89+
if process.poll() is None:
90+
log.warning(f"Forcefully terminating {name}")
91+
process.kill()
2992

30-
transport = httpx.ASGITransport(app=app)
31-
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
32-
# Patch the `get_redis` method of the `Store` class
33-
with patch.object(Store, 'get_redis', new=get_redis_override):
34-
print("")
35-
yield client
3693

3794
@pytest.fixture
38-
async def websocket_client(redis_client: redis.Redis):
39-
"""Alternative WebSocket client using Starlette TestClient."""
40-
def get_redis_override(*args, **kwargs) -> redis.Redis:
41-
return redis_client
95+
async def test_client(live_server: str) -> AsyncIterator[httpx.AsyncClient]:
96+
"""HTTP client for testing."""
97+
async with httpx.AsyncClient(base_url=f'http://{live_server}') as client:
98+
print()
99+
yield client
100+
42101

43-
# Make sure the app has the Redis state set up
44-
app.state.redis = redis_client
102+
@pytest.fixture
103+
async def websocket_client(live_server: str):
104+
"""WebSocket client for testing."""
105+
base_ws_url = f'ws://{live_server}'
106+
return WebSocketTestClient(base_ws_url)
45107

46-
# Patch the `get_redis` method of the `Store` class
47-
with patch.object(Store, 'get_redis', new=get_redis_override):
48-
with TestClient(app, base_url="http://testserver") as client:
49-
print("")
50-
yield client
51108

52109
@pytest.mark.anyio
53110
async def test_mocks(test_client: httpx.AsyncClient) -> None:
54111
response = await test_client.get("/nonexistent-endpoint")
55112
assert response.status_code == 404, "Expected 404 for nonexistent endpoint"
113+

tests/test_endpoints.py

Lines changed: 64 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
import asyncio
2-
import time
2+
import json
33
import pytest
44
import httpx
55
from fastapi import WebSocketDisconnect
66
from starlette.responses import ClientDisconnect
7+
from websockets.exceptions import ConnectionClosedError, InvalidStatus
78

89
from tests.helpers import generate_test_file
10+
from tests.ws_client import WebSocketTestClient
911

1012

1113
@pytest.mark.anyio
1214
@pytest.mark.parametrize("uid, expected_status", [
1315
("invalid_id!", 400),
1416
("bad id", 400),
1517
])
16-
async def test_invalid_uid(websocket_client, test_client: httpx.AsyncClient, uid: str, expected_status: int):
18+
async def test_invalid_uid(websocket_client: WebSocketTestClient, test_client: httpx.AsyncClient, uid: str, expected_status: int):
1719
"""Tests that endpoints reject invalid UIDs."""
1820
response_get = await test_client.get(f"/{uid}")
1921
assert response_get.status_code == expected_status
2022

2123
response_put = await test_client.put(f"/{uid}/test.txt")
2224
assert response_put.status_code == expected_status
2325

24-
with pytest.raises(WebSocketDisconnect):
25-
with websocket_client.websocket_connect(f"/send/{uid}"): # type: ignore
26-
pass # Connection should be rejected immediately
26+
with pytest.raises((ConnectionClosedError, InvalidStatus)):
27+
async with websocket_client.websocket_connect(f"/send/{uid}") as _: # type: ignore
28+
pass
2729

2830

2931
@pytest.mark.anyio
@@ -35,86 +37,83 @@ async def test_slash_in_uid_routes_to_404(test_client: httpx.AsyncClient):
3537

3638

3739
@pytest.mark.anyio
38-
async def test_transfer_id_already_used(websocket_client):
40+
async def test_transfer_id_already_used(websocket_client: WebSocketTestClient):
3941
"""Tests that creating a transfer with an existing ID fails."""
4042
uid = "duplicate-id"
4143
_, file_metadata = generate_test_file()
4244

4345
# First creation should succeed
44-
with websocket_client.websocket_connect(f"/send/{uid}") as ws:
45-
ws.send_json({
46+
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
47+
await ws.send_json({
4648
'file_name': file_metadata.name,
4749
'file_size': file_metadata.size,
4850
'file_type': file_metadata.type
4951
})
5052

5153
# Second attempt should fail with an error message
52-
with websocket_client.websocket_connect(f"/send/{uid}") as ws2:
53-
ws2.send_json({
54+
async with websocket_client.websocket_connect(f"/send/{uid}") as ws2:
55+
await ws2.send_json({
5456
'file_name': file_metadata.name,
5557
'file_size': file_metadata.size,
5658
'file_type': file_metadata.type
5759
})
58-
response = ws2.receive_text()
60+
response = await ws2.recv()
5961
assert "Error: Transfer ID is already used." in response
6062

6163

62-
@pytest.mark.anyio
63-
async def test_sender_timeout(websocket_client, monkeypatch):
64-
"""Tests that the sender times out if the receiver doesn't connect."""
65-
uid = "sender-timeout"
66-
_, file_metadata = generate_test_file()
64+
# @pytest.mark.anyio
65+
# async def test_sender_timeout(websocket_client, monkeypatch):
66+
# """Tests that the sender times out if the receiver doesn't connect."""
67+
# uid = "sender-timeout"
68+
# _, file_metadata = generate_test_file()
6769

68-
# Override the timeout for the test to make it fail quickly
69-
async def mock_wait_for_client_connected(self):
70-
await asyncio.sleep(1.0) # Short delay
71-
raise asyncio.TimeoutError("Mocked timeout")
70+
# # Override the timeout for the test to make it fail quickly
71+
# async def mock_wait_for_client_connected(self):
72+
# await asyncio.sleep(1.0) # Short delay
73+
# raise asyncio.TimeoutError("Mocked timeout")
7274

73-
from lib.transfer import FileTransfer
74-
monkeypatch.setattr(FileTransfer, 'wait_for_client_connected', mock_wait_for_client_connected)
75+
# from lib.transfer import FileTransfer
76+
# monkeypatch.setattr(FileTransfer, 'wait_for_client_connected', mock_wait_for_client_connected)
7577

76-
with websocket_client.websocket_connect(f"/send/{uid}") as ws:
77-
ws.send_json({
78-
'file_name': file_metadata.name,
79-
'file_size': file_metadata.size,
80-
'file_type': file_metadata.type
81-
})
82-
# This should timeout because we are not starting a receiver
83-
response = ws.receive_text()
84-
assert "Error: Receiver did not connect in time." in response
78+
# async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
79+
# await ws.websocket.send(json.dumps({
80+
# 'file_name': file_metadata.name,
81+
# 'file_size': file_metadata.size,
82+
# 'file_type': file_metadata.type
83+
# }))
84+
# # This should timeout because we are not starting a receiver
85+
# response = await ws.websocket.recv()
86+
# assert "Error: Receiver did not connect in time." in response
8587

8688

8789
@pytest.mark.anyio
88-
async def test_receiver_disconnects(test_client: httpx.AsyncClient, websocket_client):
90+
async def test_receiver_disconnects(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient):
8991
"""Tests that the sender is notified if the receiver disconnects mid-transfer."""
9092
uid = "receiver-disconnect"
9193
file_content, file_metadata = generate_test_file(size_in_kb=128) # Larger file
9294

9395
async def sender():
94-
# with pytest.raises(ClientDisconnect, check=lambda e: "Received less data than expected" in str(e)):
95-
with websocket_client.websocket_connect(f"/send/{uid}") as ws:
96-
await asyncio.sleep(0.1)
97-
98-
ws.send_json({
99-
'file_name': file_metadata.name,
100-
'file_size': file_metadata.size,
101-
'file_type': file_metadata.type
102-
})
103-
await asyncio.sleep(1.0) # Allow receiver to connect
96+
with pytest.raises(ConnectionClosedError, match="Transfer was interrupted by the receiver"):
97+
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
98+
await asyncio.sleep(0.1)
10499

105-
response = ws.receive_text()
106-
await asyncio.sleep(0.1)
107-
assert response == "Go for file chunks"
100+
await ws.send_json({
101+
'file_name': file_metadata.name,
102+
'file_size': file_metadata.size,
103+
'file_type': file_metadata.type
104+
})
105+
await asyncio.sleep(1.0) # Allow receiver to connect
108106

109-
chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)]
110-
for chunk in chunks:
111-
ws.send_bytes(chunk)
107+
response = await ws.recv()
112108
await asyncio.sleep(0.1)
109+
assert response == "Go for file chunks"
113110

114-
await asyncio.sleep(2.0)
115-
116-
await asyncio.sleep(2.0)
111+
chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)]
112+
for chunk in chunks:
113+
await ws.send_bytes(chunk)
114+
await asyncio.sleep(0.1)
117115

116+
await asyncio.sleep(2.0)
118117

119118
async def receiver():
120119
await asyncio.sleep(1.0)
@@ -125,32 +124,31 @@ async def receiver():
125124

126125
response.raise_for_status()
127126
i = 0
128-
# with pytest.raises(ClientDisconnect):
129-
async for chunk in response.aiter_bytes(4096):
130-
if not chunk:
131-
break
132-
i += 1
133-
if i >= 5:
134-
return
135-
await asyncio.sleep(0.025)
127+
with pytest.raises(ClientDisconnect):
128+
async for chunk in response.aiter_bytes(4096):
129+
if not chunk:
130+
break
131+
i += 1
132+
if i >= 5:
133+
raise ClientDisconnect("Simulated disconnect")
134+
await asyncio.sleep(0.025)
136135

137136
t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
138137
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
139138
await asyncio.gather(t1, t2)
140139

141140

142-
143141
@pytest.mark.anyio
144-
async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_client):
142+
async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient):
145143
"""Tests that prefetcher user agents are served a preview page."""
146144
uid = "prefetch-test"
147145
_, file_metadata = generate_test_file()
148146

149147
# Create a dummy transfer to get metadata
150-
with websocket_client.websocket_connect(f"/send/{uid}") as ws:
148+
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
151149
await asyncio.sleep(0.1)
152150

153-
ws.send_json({
151+
await ws.send_json({
154152
'file_name': file_metadata.name,
155153
'file_size': file_metadata.size,
156154
'file_type': file_metadata.type
@@ -168,15 +166,15 @@ async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_clie
168166

169167

170168
@pytest.mark.anyio
171-
async def test_browser_download_page(test_client: httpx.AsyncClient, websocket_client):
169+
async def test_browser_download_page(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient):
172170
"""Tests that a browser is served the download page."""
173171
uid = "browser-download-page"
174172
_, file_metadata = generate_test_file()
175173

176-
with websocket_client.websocket_connect(f"/send/{uid}") as ws:
174+
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
177175
await asyncio.sleep(0.1)
178176

179-
ws.send_json({
177+
await ws.send_json({
180178
'file_name': file_metadata.name,
181179
'file_size': file_metadata.size,
182180
'file_type': file_metadata.type

0 commit comments

Comments
 (0)