Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 24 additions & 32 deletions lib/store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import asyncio
import anyio
import redis.asyncio as redis
from redis.asyncio.client import PubSub
from typing import Optional, Annotated

from lib.logging import HasLogging, get_logger
Expand Down Expand Up @@ -39,19 +40,19 @@ def key(self, name: str) -> str:

async def _wait_for_queue_space(self, maxsize: int) -> None:
while await self.redis.llen(self._k_queue) >= maxsize:
await asyncio.sleep(0.5)
await anyio.sleep(0.5)

async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 20.0) -> None:
"""Add data to the transfer queue with backpressure control."""
async with asyncio.timeout(timeout):
with anyio.fail_after(timeout):
await self._wait_for_queue_space(maxsize)
await self.redis.lpush(self._k_queue, data)

async def get_from_queue(self, timeout: float = 20.0) -> bytes:
"""Get data from the transfer queue with timeout."""
result = await self.redis.brpop([self._k_queue], timeout=timeout)
if not result:
raise asyncio.TimeoutError("Timeout waiting for data")
raise TimeoutError("Timeout waiting for data")

_, data = result
return data
Expand All @@ -66,46 +67,37 @@ async def set_event(self, event_name: str, expiry: float = 300.0) -> None:
await self.redis.set(event_marker_key, '1', ex=int(expiry))
await self.redis.publish(event_key, '1')

async def _poll_marker(self, event_key: str) -> None:
"""Poll for event marker existence."""
event_marker_key = f'{event_key}:marker'
while not await self.redis.exists(event_marker_key):
await anyio.sleep(1)

async def _listen_for_message(self, pubsub: PubSub, event_key: str) -> None:
"""Listen for pubsub messages."""
await pubsub.subscribe(event_key)
async for message in pubsub.listen():
if message and message['type'] == 'message':
return

async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None:
"""Wait for an event to be set for this transfer."""
event_key = self.key(event_name)
event_marker_key = f'{event_key}:marker'
pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
await pubsub.subscribe(event_key)

async def _poll_marker():
while not await self.redis.exists(event_marker_key):
await asyncio.sleep(1)
self.debug(f">> POLL: Event '{event_name}' fired.")

async def _listen_for_message():
async for message in pubsub.listen():
if message and message['type'] == 'message':
self.debug(f">> SUB : Received message for event '{event_name}'.")
return

poll_marker = asyncio.wait_for(_poll_marker(), timeout=timeout)
listen_for_message = asyncio.wait_for(_listen_for_message(), timeout=timeout)

try:
tasks = {
asyncio.create_task(poll_marker, name=f'poll_marker_{event_name}_{self.transfer_id}'),
asyncio.create_task(listen_for_message, name=f'listen_for_message_{event_name}_{self.transfer_id}')
}
_, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()

except asyncio.TimeoutError:
with anyio.fail_after(timeout):
async with anyio.create_task_group() as tg:
tg.start_soon(self._poll_marker, event_key)
tg.start_soon(self._listen_for_message, pubsub, event_key)

except TimeoutError:
self.error(f"Timeout waiting for event '{event_name}' after {timeout} seconds.")
for task in tasks:
task.cancel()
raise

finally:
await pubsub.unsubscribe(event_key)
await pubsub.aclose()
await asyncio.gather(*tasks, return_exceptions=True)

## Metadata operations ##

Expand Down
13 changes: 7 additions & 6 deletions lib/transfer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
import anyio
from starlette.responses import ClientDisconnect
from starlette.websockets import WebSocketDisconnect
from typing import AsyncIterator, Callable, Awaitable, Optional, Any
Expand Down Expand Up @@ -113,8 +113,8 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
self.error(f"△ Unexpected upload error: {e}")
await self.store.put_in_queue(self.DEAD_FLAG)

except asyncio.TimeoutError as e:
self.warning(f"△ Timeout during upload.")
except TimeoutError as e:
self.warning(f"△ Timeout during upload.", exc_info=True)
await on_error("Timeout during upload.")

except TransferError as e:
Expand All @@ -125,7 +125,7 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
await on_error(e)

finally:
await asyncio.sleep(1.0)
await anyio.sleep(1.0)

async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[None]]) -> AsyncIterator[bytes]:
self.bytes_downloaded = 0
Expand Down Expand Up @@ -158,8 +158,9 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[

async def cleanup(self):
try:
await asyncio.wait_for(self.store.cleanup(), timeout=30.0)
except asyncio.TimeoutError:
with anyio.fail_after(30.0):
await self.store.cleanup()
except TimeoutError:
self.warning(f"- Cleanup timed out.")
pass

Expand Down
5 changes: 3 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import anyio
from string import ascii_letters
from itertools import islice, repeat, chain
from typing import Tuple, Iterable, AsyncIterator
from annotated_types import T
import anyio.lowlevel

from lib.metadata import FileMetadata

Expand All @@ -24,4 +25,4 @@ async def chunks(data: bytes, chunk_size: int = 1024) -> AsyncIterator[bytes]:
"""Yield successive chunks of data."""
for i in range(0, len(data), chunk_size):
yield data[i:i + chunk_size]
await asyncio.sleep(0)
await anyio.lowlevel.checkpoint()
38 changes: 19 additions & 19 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
import anyio
import json
import pytest
import httpx
Expand Down Expand Up @@ -69,7 +69,7 @@ async def test_transfer_id_already_used(websocket_client: WebSocketTestClient):

# # Override the timeout for the test to make it fail quickly
# async def mock_wait_for_client_connected(self):
# await asyncio.sleep(1.0) # Short delay
# await anyio.sleep(1.0) # Short delay
# raise asyncio.TimeoutError("Mocked timeout")

# from lib.transfer import FileTransfer
Expand All @@ -95,32 +95,32 @@ async def test_receiver_disconnects(test_client: httpx.AsyncClient, websocket_cl
async def sender():
with pytest.raises(ConnectionClosedError, match="Transfer was interrupted by the receiver"):
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

await ws.send_json({
'file_name': file_metadata.name,
'file_size': file_metadata.size,
'file_type': file_metadata.type
})
await asyncio.sleep(1.0) # Allow receiver to connect
await anyio.sleep(1.0) # Allow receiver to connect

response = await ws.recv()
await asyncio.sleep(0.1)
await anyio.sleep(0.1)
assert response == "Go for file chunks"

chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)]
for chunk in chunks:
await ws.send_bytes(chunk)
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

await asyncio.sleep(2.0)
await anyio.sleep(2.0)

async def receiver():
await asyncio.sleep(1.0)
await anyio.sleep(1.0)
headers = {'Accept': '*/*'}

async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response:
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

response.raise_for_status()
i = 0
Expand All @@ -131,11 +131,11 @@ async def receiver():
i += 1
if i >= 5:
raise ClientDisconnect("Simulated disconnect")
await asyncio.sleep(0.025)
await anyio.sleep(0.025)

t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
await asyncio.gather(t1, t2)
async with anyio.create_task_group() as tg:
tg.start_soon(sender)
tg.start_soon(receiver)


@pytest.mark.anyio
Expand All @@ -146,18 +146,18 @@ async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_clie

# Create a dummy transfer to get metadata
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

await ws.send_json({
'file_name': file_metadata.name,
'file_size': file_metadata.size,
'file_type': file_metadata.type
})
await asyncio.sleep(1.0)
await anyio.sleep(1.0)

headers = {'User-Agent': 'facebookexternalhit/1.1'}
response = await test_client.get(f"/{uid}", headers=headers)
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

assert response.status_code == 200
assert "text/html" in response.headers['content-type']
Expand All @@ -172,18 +172,18 @@ async def test_browser_download_page(test_client: httpx.AsyncClient, websocket_c
_, file_metadata = generate_test_file()

async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

await ws.send_json({
'file_name': file_metadata.name,
'file_size': file_metadata.size,
'file_type': file_metadata.type
})
await asyncio.sleep(1.0)
await anyio.sleep(1.0)

headers = {'User-Agent': 'Mozilla/5.0'}
response = await test_client.get(f"/{uid}", headers=headers)
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

assert response.status_code == 200
assert "text/html" in response.headers['content-type']
Expand Down
44 changes: 22 additions & 22 deletions tests/test_journeys.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio
import anyio
import httpx
import json
import pytest
Expand All @@ -15,55 +15,55 @@ async def test_websocket_upload_http_download(test_client: httpx.AsyncClient, we

async def sender():
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

await ws.websocket.send(json.dumps({
'file_name': file_metadata.name,
'file_size': file_metadata.size,
'file_type': file_metadata.type
}))
await asyncio.sleep(1.0)
await anyio.sleep(1.0)

# Wait for receiver to connect
response = await ws.websocket.recv()
await asyncio.sleep(0.1)
await anyio.sleep(0.1)
assert response == "Go for file chunks"

# Send file
chunk_size = 4096
for i in range(0, len(file_content), chunk_size):
await ws.websocket.send(file_content[i:i + chunk_size])
await asyncio.sleep(0.025)
await anyio.sleep(0.025)

await ws.websocket.send(b'') # End of file
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

async def receiver():
await asyncio.sleep(1.0)
await anyio.sleep(1.0)
headers = {'User-Agent': 'Mozilla/5.0', 'Accept': '*/*'}

async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response:
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

response.raise_for_status()
assert response.headers['content-length'] == str(file_metadata.size)
assert f"filename={file_metadata.name}" in response.headers['content-disposition']
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

downloaded_content = b''
async for chunk in response.aiter_bytes(4096):
if not chunk or len(downloaded_content) >= file_metadata.size:
break
downloaded_content += chunk
await asyncio.sleep(0.025)
await anyio.sleep(0.025)

assert len(downloaded_content) == file_metadata.size
assert downloaded_content == file_content
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
await asyncio.gather(t1, t2, return_exceptions=True)
async with anyio.create_task_group() as tg:
tg.start_soon(sender)
tg.start_soon(receiver)


@pytest.mark.anyio
Expand All @@ -78,22 +78,22 @@ async def sender():
'Content-Length': str(file_metadata.size)
}
async with test_client.stream("PUT", f"/{uid}/{file_metadata.name}", content=file_content, headers=headers) as response:
await asyncio.sleep(1.0)
await anyio.sleep(1.0)

response.raise_for_status()
assert response.status_code == 200
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

async def receiver():
await asyncio.sleep(1.0)
await anyio.sleep(1.0)
response = await test_client.get(f"/{uid}?download=true")
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

response.raise_for_status()
assert response.content == file_content
assert len(response.content) == file_metadata.size
await asyncio.sleep(0.1)
await anyio.sleep(0.1)

t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
await asyncio.gather(t1, t2, return_exceptions=True)
async with anyio.create_task_group() as tg:
tg.start_soon(sender)
tg.start_soon(receiver)
4 changes: 2 additions & 2 deletions views/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import string
import asyncio
import anyio
from fastapi import Request, APIRouter
from fastapi.templating import Jinja2Templates
from starlette.background import BackgroundTask
Expand Down Expand Up @@ -55,7 +55,7 @@ async def http_upload(request: Request, uid: str, filename: str):

try:
await transfer.wait_for_client_connected()
except asyncio.TimeoutError:
except TimeoutError:
log.warning("△ Receiver did not connect in time.")
raise HTTPException(status_code=408, detail="Client did not connect in time.")

Expand Down
Loading