Skip to content

Commit cb0c714

Browse files
committed
New unit tests
1 parent 60adf9a commit cb0c714

File tree

7 files changed

+398
-411
lines changed

7 files changed

+398
-411
lines changed

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
run: |
2626
python -m pip install --upgrade pip
2727
pip install -r requirements.txt
28-
pip install pytest
28+
pip install pytest fakeredis
2929
3030
- name: Run tests
3131
run: pytest --disable-pytest-warnings -v tests/

app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ async def lifespan(app: FastAPI):
2929

3030
# App. routes
3131
app.include_router(misc_router)
32+
app.include_router(ws_router) # WebSocket routes first to avoid conflicts
3233
app.include_router(http_router)
33-
app.include_router(ws_router)
3434

3535
# Static files
3636
app.mount('/', StaticFiles(directory='static', html=True), name='static')

tests/conftest.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,55 @@
1-
import asyncio
1+
import httpx
22
import pytest
3+
import fakeredis
4+
from typing import AsyncIterator
5+
from redis import asyncio as redis
36
from unittest.mock import AsyncMock, patch
7+
from starlette.testclient import TestClient
48

9+
from app import app
10+
from lib.store import Store
511

6-
@pytest.fixture(scope="session")
7-
def event_loop():
8-
"""Create an instance of the default event loop for the test session."""
9-
loop = asyncio.new_event_loop()
10-
yield loop
11-
loop.close()
12-
13-
@pytest.fixture(autouse=True)
14-
def setup_test_environment():
15-
"""Set up test environment before each test."""
16-
# Mock Redis client to avoid needing actual Redis instance
17-
with patch('app.redis') as mock_redis:
18-
mock_redis.close = AsyncMock()
19-
yield mock_redis
12+
13+
@pytest.fixture
14+
def anyio_backend():
15+
return 'asyncio'
16+
17+
@pytest.fixture
18+
async def redis_client() -> AsyncIterator[redis.Redis]:
19+
async with fakeredis.FakeAsyncRedis() as client:
20+
yield client
21+
22+
@pytest.fixture
23+
async def test_client(redis_client: redis.Redis) -> AsyncIterator[httpx.AsyncClient]:
24+
def get_redis_override(*args, **kwargs) -> redis.Redis:
25+
return redis_client
26+
27+
# Make sure the app has the Redis state set up
28+
app.state.redis = redis_client
29+
30+
transport = httpx.ASGITransport(app=app)
31+
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
32+
# Patch the `get_redis` method of the `Store` class
33+
with patch.object(Store, 'get_redis', new=get_redis_override):
34+
print("")
35+
yield client
36+
37+
@pytest.fixture
38+
async def websocket_client(redis_client: redis.Redis):
39+
"""Alternative WebSocket client using Starlette TestClient."""
40+
def get_redis_override(*args, **kwargs) -> redis.Redis:
41+
return redis_client
42+
43+
# Make sure the app has the Redis state set up
44+
app.state.redis = redis_client
45+
46+
# Patch the `get_redis` method of the `Store` class
47+
with patch.object(Store, 'get_redis', new=get_redis_override):
48+
with TestClient(app) as client:
49+
print("")
50+
yield client
51+
52+
@pytest.mark.anyio
53+
async def test_mocks(test_client: httpx.AsyncClient) -> None:
54+
response = await test_client.get("/nonexistent-endpoint")
55+
assert response.status_code == 404, "Expected 404 for nonexistent endpoint"

tests/helpers.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from string import ascii_letters
2+
from itertools import islice, repeat, chain
3+
from typing import Tuple, Iterable, Iterator
4+
from annotated_types import T
5+
6+
from lib.metadata import FileMetadata
7+
8+
9+
def generate_test_file(size_in_kb: int = 10) -> tuple[bytes, FileMetadata]:
10+
"""Generates a test file with specified size in KB."""
11+
chunk_generator = ((letter * 1024).encode() for letter in chain.from_iterable(repeat(ascii_letters)))
12+
content = b''.join(next(chunk_generator) for _ in range(size_in_kb))
13+
14+
metadata = FileMetadata(
15+
name="test_file.bin",
16+
size=len(content),
17+
content_type="application/octet-stream"
18+
)
19+
return content, metadata
20+
21+
22+
def batched(iterable: Iterable[T], chunk_size: int) -> Iterator[Tuple[T, ...]]:
23+
"Batch data into lists of length n. The last batch may be shorter."
24+
# batched('ABCDEFG', 3) --> ABC DEF G
25+
it = iter(iterable)
26+
while True:
27+
batch = bytes(islice(it, chunk_size))
28+
if not batch:
29+
return
30+
yield batch

0 commit comments

Comments
 (0)