diff --git a/lib/store.py b/lib/store.py index 31ad1e6..a62d7c4 100644 --- a/lib/store.py +++ b/lib/store.py @@ -1,6 +1,7 @@ import random -import asyncio +import anyio import redis.asyncio as redis +from redis.asyncio.client import PubSub from typing import Optional, Annotated from lib.logging import HasLogging, get_logger @@ -39,11 +40,11 @@ def key(self, name: str) -> str: async def _wait_for_queue_space(self, maxsize: int) -> None: while await self.redis.llen(self._k_queue) >= maxsize: - await asyncio.sleep(0.5) + await anyio.sleep(0.5) async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 20.0) -> None: """Add data to the transfer queue with backpressure control.""" - async with asyncio.timeout(timeout): + with anyio.fail_after(timeout): await self._wait_for_queue_space(maxsize) await self.redis.lpush(self._k_queue, data) @@ -51,7 +52,7 @@ async def get_from_queue(self, timeout: float = 20.0) -> bytes: """Get data from the transfer queue with timeout.""" result = await self.redis.brpop([self._k_queue], timeout=timeout) if not result: - raise asyncio.TimeoutError("Timeout waiting for data") + raise TimeoutError("Timeout waiting for data") _, data = result return data @@ -66,46 +67,37 @@ async def set_event(self, event_name: str, expiry: float = 300.0) -> None: await self.redis.set(event_marker_key, '1', ex=int(expiry)) await self.redis.publish(event_key, '1') + async def _poll_marker(self, event_key: str) -> None: + """Poll for event marker existence.""" + event_marker_key = f'{event_key}:marker' + while not await self.redis.exists(event_marker_key): + await anyio.sleep(1) + + async def _listen_for_message(self, pubsub: PubSub, event_key: str) -> None: + """Listen for pubsub messages.""" + await pubsub.subscribe(event_key) + async for message in pubsub.listen(): + if message and message['type'] == 'message': + return + async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None: """Wait for an event to be set for this transfer.""" event_key = self.key(event_name) - event_marker_key = f'{event_key}:marker' pubsub = self.redis.pubsub(ignore_subscribe_messages=True) - await pubsub.subscribe(event_key) - - async def _poll_marker(): - while not await self.redis.exists(event_marker_key): - await asyncio.sleep(1) - self.debug(f">> POLL: Event '{event_name}' fired.") - - async def _listen_for_message(): - async for message in pubsub.listen(): - if message and message['type'] == 'message': - self.debug(f">> SUB : Received message for event '{event_name}'.") - return - - poll_marker = asyncio.wait_for(_poll_marker(), timeout=timeout) - listen_for_message = asyncio.wait_for(_listen_for_message(), timeout=timeout) try: - tasks = { - asyncio.create_task(poll_marker, name=f'poll_marker_{event_name}_{self.transfer_id}'), - asyncio.create_task(listen_for_message, name=f'listen_for_message_{event_name}_{self.transfer_id}') - } - _, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - for task in pending: - task.cancel() - - except asyncio.TimeoutError: + with anyio.fail_after(timeout): + async with anyio.create_task_group() as tg: + tg.start_soon(self._poll_marker, event_key) + tg.start_soon(self._listen_for_message, pubsub, event_key) + + except TimeoutError: self.error(f"Timeout waiting for event '{event_name}' after {timeout} seconds.") - for task in tasks: - task.cancel() raise finally: await pubsub.unsubscribe(event_key) await pubsub.aclose() - await asyncio.gather(*tasks, return_exceptions=True) ## Metadata operations ## diff --git a/lib/transfer.py b/lib/transfer.py index 51f4af9..d8af9d8 100644 --- a/lib/transfer.py +++ b/lib/transfer.py @@ -1,4 +1,4 @@ -import asyncio +import anyio from starlette.responses import ClientDisconnect from starlette.websockets import WebSocketDisconnect from typing import AsyncIterator, Callable, Awaitable, Optional, Any @@ -113,8 +113,8 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[ self.error(f"△ Unexpected upload error: {e}") await self.store.put_in_queue(self.DEAD_FLAG) - except asyncio.TimeoutError as e: - self.warning(f"△ Timeout during upload.") + except TimeoutError as e: + self.warning(f"△ Timeout during upload.", exc_info=True) await on_error("Timeout during upload.") except TransferError as e: @@ -125,7 +125,7 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[ await on_error(e) finally: - await asyncio.sleep(1.0) + await anyio.sleep(1.0) async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[None]]) -> AsyncIterator[bytes]: self.bytes_downloaded = 0 @@ -158,8 +158,9 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[ async def cleanup(self): try: - await asyncio.wait_for(self.store.cleanup(), timeout=30.0) - except asyncio.TimeoutError: + with anyio.fail_after(30.0): + await self.store.cleanup() + except TimeoutError: self.warning(f"- Cleanup timed out.") pass diff --git a/tests/helpers.py b/tests/helpers.py index b653e42..32f4811 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,8 +1,9 @@ -import asyncio +import anyio from string import ascii_letters from itertools import islice, repeat, chain from typing import Tuple, Iterable, AsyncIterator from annotated_types import T +import anyio.lowlevel from lib.metadata import FileMetadata @@ -24,4 +25,4 @@ async def chunks(data: bytes, chunk_size: int = 1024) -> AsyncIterator[bytes]: """Yield successive chunks of data.""" for i in range(0, len(data), chunk_size): yield data[i:i + chunk_size] - await asyncio.sleep(0) + await anyio.lowlevel.checkpoint() diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index d6ec73b..9495c26 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,4 +1,4 @@ -import asyncio +import anyio import json import pytest import httpx @@ -69,7 +69,7 @@ async def test_transfer_id_already_used(websocket_client: WebSocketTestClient): # # Override the timeout for the test to make it fail quickly # async def mock_wait_for_client_connected(self): -# await asyncio.sleep(1.0) # Short delay +# await anyio.sleep(1.0) # Short delay # raise asyncio.TimeoutError("Mocked timeout") # from lib.transfer import FileTransfer @@ -95,32 +95,32 @@ async def test_receiver_disconnects(test_client: httpx.AsyncClient, websocket_cl async def sender(): with pytest.raises(ConnectionClosedError, match="Transfer was interrupted by the receiver"): async with websocket_client.websocket_connect(f"/send/{uid}") as ws: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await ws.send_json({ 'file_name': file_metadata.name, 'file_size': file_metadata.size, 'file_type': file_metadata.type }) - await asyncio.sleep(1.0) # Allow receiver to connect + await anyio.sleep(1.0) # Allow receiver to connect response = await ws.recv() - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert response == "Go for file chunks" chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)] for chunk in chunks: await ws.send_bytes(chunk) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) - await asyncio.sleep(2.0) + await anyio.sleep(2.0) async def receiver(): - await asyncio.sleep(1.0) + await anyio.sleep(1.0) headers = {'Accept': '*/*'} async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) response.raise_for_status() i = 0 @@ -131,11 +131,11 @@ async def receiver(): i += 1 if i >= 5: raise ClientDisconnect("Simulated disconnect") - await asyncio.sleep(0.025) + await anyio.sleep(0.025) - t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15)) - t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15)) - await asyncio.gather(t1, t2) + async with anyio.create_task_group() as tg: + tg.start_soon(sender) + tg.start_soon(receiver) @pytest.mark.anyio @@ -146,18 +146,18 @@ async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_clie # Create a dummy transfer to get metadata async with websocket_client.websocket_connect(f"/send/{uid}") as ws: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await ws.send_json({ 'file_name': file_metadata.name, 'file_size': file_metadata.size, 'file_type': file_metadata.type }) - await asyncio.sleep(1.0) + await anyio.sleep(1.0) headers = {'User-Agent': 'facebookexternalhit/1.1'} response = await test_client.get(f"/{uid}", headers=headers) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert response.status_code == 200 assert "text/html" in response.headers['content-type'] @@ -172,18 +172,18 @@ async def test_browser_download_page(test_client: httpx.AsyncClient, websocket_c _, file_metadata = generate_test_file() async with websocket_client.websocket_connect(f"/send/{uid}") as ws: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await ws.send_json({ 'file_name': file_metadata.name, 'file_size': file_metadata.size, 'file_type': file_metadata.type }) - await asyncio.sleep(1.0) + await anyio.sleep(1.0) headers = {'User-Agent': 'Mozilla/5.0'} response = await test_client.get(f"/{uid}", headers=headers) - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert response.status_code == 200 assert "text/html" in response.headers['content-type'] diff --git a/tests/test_journeys.py b/tests/test_journeys.py index 5e28e13..4b47c80 100644 --- a/tests/test_journeys.py +++ b/tests/test_journeys.py @@ -1,4 +1,4 @@ -import asyncio +import anyio import httpx import json import pytest @@ -15,55 +15,55 @@ async def test_websocket_upload_http_download(test_client: httpx.AsyncClient, we async def sender(): async with websocket_client.websocket_connect(f"/send/{uid}") as ws: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) await ws.websocket.send(json.dumps({ 'file_name': file_metadata.name, 'file_size': file_metadata.size, 'file_type': file_metadata.type })) - await asyncio.sleep(1.0) + await anyio.sleep(1.0) # Wait for receiver to connect response = await ws.websocket.recv() - await asyncio.sleep(0.1) + await anyio.sleep(0.1) assert response == "Go for file chunks" # Send file chunk_size = 4096 for i in range(0, len(file_content), chunk_size): await ws.websocket.send(file_content[i:i + chunk_size]) - await asyncio.sleep(0.025) + await anyio.sleep(0.025) await ws.websocket.send(b'') # End of file - await asyncio.sleep(0.1) + await anyio.sleep(0.1) async def receiver(): - await asyncio.sleep(1.0) + await anyio.sleep(1.0) headers = {'User-Agent': 'Mozilla/5.0', 'Accept': '*/*'} async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response: - await asyncio.sleep(0.1) + await anyio.sleep(0.1) response.raise_for_status() assert response.headers['content-length'] == str(file_metadata.size) assert f"filename={file_metadata.name}" in response.headers['content-disposition'] - await asyncio.sleep(0.1) + await anyio.sleep(0.1) downloaded_content = b'' async for chunk in response.aiter_bytes(4096): if not chunk or len(downloaded_content) >= file_metadata.size: break downloaded_content += chunk - await asyncio.sleep(0.025) + await anyio.sleep(0.025) assert len(downloaded_content) == file_metadata.size assert downloaded_content == file_content - await asyncio.sleep(0.1) + await anyio.sleep(0.1) - t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15)) - t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15)) - await asyncio.gather(t1, t2, return_exceptions=True) + async with anyio.create_task_group() as tg: + tg.start_soon(sender) + tg.start_soon(receiver) @pytest.mark.anyio @@ -78,22 +78,22 @@ async def sender(): 'Content-Length': str(file_metadata.size) } async with test_client.stream("PUT", f"/{uid}/{file_metadata.name}", content=file_content, headers=headers) as response: - await asyncio.sleep(1.0) + await anyio.sleep(1.0) response.raise_for_status() assert response.status_code == 200 - await asyncio.sleep(0.1) + await anyio.sleep(0.1) async def receiver(): - await asyncio.sleep(1.0) + await anyio.sleep(1.0) response = await test_client.get(f"/{uid}?download=true") - await asyncio.sleep(0.1) + await anyio.sleep(0.1) response.raise_for_status() assert response.content == file_content assert len(response.content) == file_metadata.size - await asyncio.sleep(0.1) + await anyio.sleep(0.1) - t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15)) - t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15)) - await asyncio.gather(t1, t2, return_exceptions=True) + async with anyio.create_task_group() as tg: + tg.start_soon(sender) + tg.start_soon(receiver) diff --git a/views/http.py b/views/http.py index 4ad425d..64b2402 100644 --- a/views/http.py +++ b/views/http.py @@ -1,5 +1,5 @@ import string -import asyncio +import anyio from fastapi import Request, APIRouter from fastapi.templating import Jinja2Templates from starlette.background import BackgroundTask @@ -55,7 +55,7 @@ async def http_upload(request: Request, uid: str, filename: str): try: await transfer.wait_for_client_connected() - except asyncio.TimeoutError: + except TimeoutError: log.warning("△ Receiver did not connect in time.") raise HTTPException(status_code=408, detail="Client did not connect in time.") diff --git a/views/websockets.py b/views/websockets.py index 6912acd..33656fb 100644 --- a/views/websockets.py +++ b/views/websockets.py @@ -1,9 +1,7 @@ import string -import asyncio import warnings -from fastapi import WebSocket, APIRouter, WebSocketDisconnect, BackgroundTasks -from fastapi.responses import PlainTextResponse from pydantic import ValidationError +from fastapi import WebSocket, APIRouter, WebSocketDisconnect, BackgroundTasks from lib.logging import get_logger from lib.callbacks import send_error_and_close @@ -57,7 +55,7 @@ async def websocket_upload(websocket: WebSocket, uid: str): try: await transfer.wait_for_client_connected() - except asyncio.TimeoutError: + except TimeoutError: log.warning("△ Receiver did not connect in time.") await websocket.send_text(f"Error: Receiver did not connect in time.") return