Skip to content

Commit dfe7804

Browse files
committed
Bulk of the work
1 parent 582d745 commit dfe7804

File tree

7 files changed

+79
-87
lines changed

7 files changed

+79
-87
lines changed

lib/store.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import random
2-
import asyncio
2+
import anyio
33
import redis.asyncio as redis
4+
from redis.asyncio.client import PubSub
45
from typing import Optional, Annotated
56

67
from lib.logging import HasLogging, get_logger
@@ -39,19 +40,19 @@ def key(self, name: str) -> str:
3940

4041
async def _wait_for_queue_space(self, maxsize: int) -> None:
4142
while await self.redis.llen(self._k_queue) >= maxsize:
42-
await asyncio.sleep(0.5)
43+
await anyio.sleep(0.5)
4344

4445
async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 20.0) -> None:
4546
"""Add data to the transfer queue with backpressure control."""
46-
async with asyncio.timeout(timeout):
47+
with anyio.fail_after(timeout):
4748
await self._wait_for_queue_space(maxsize)
4849
await self.redis.lpush(self._k_queue, data)
4950

5051
async def get_from_queue(self, timeout: float = 20.0) -> bytes:
5152
"""Get data from the transfer queue with timeout."""
5253
result = await self.redis.brpop([self._k_queue], timeout=timeout)
5354
if not result:
54-
raise asyncio.TimeoutError("Timeout waiting for data")
55+
raise TimeoutError("Timeout waiting for data")
5556

5657
_, data = result
5758
return data
@@ -66,46 +67,37 @@ async def set_event(self, event_name: str, expiry: float = 300.0) -> None:
6667
await self.redis.set(event_marker_key, '1', ex=int(expiry))
6768
await self.redis.publish(event_key, '1')
6869

70+
async def _poll_marker(self, event_key: str) -> None:
71+
"""Poll for event marker existence."""
72+
event_marker_key = f'{event_key}:marker'
73+
while not await self.redis.exists(event_marker_key):
74+
await anyio.sleep(1)
75+
76+
async def _listen_for_message(self, pubsub: PubSub, event_key: str) -> None:
77+
"""Listen for pubsub messages."""
78+
await pubsub.subscribe(event_key)
79+
async for message in pubsub.listen():
80+
if message and message['type'] == 'message':
81+
return
82+
6983
async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None:
7084
"""Wait for an event to be set for this transfer."""
7185
event_key = self.key(event_name)
72-
event_marker_key = f'{event_key}:marker'
7386
pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
74-
await pubsub.subscribe(event_key)
75-
76-
async def _poll_marker():
77-
while not await self.redis.exists(event_marker_key):
78-
await asyncio.sleep(1)
79-
self.debug(f">> POLL: Event '{event_name}' fired.")
80-
81-
async def _listen_for_message():
82-
async for message in pubsub.listen():
83-
if message and message['type'] == 'message':
84-
self.debug(f">> SUB : Received message for event '{event_name}'.")
85-
return
86-
87-
poll_marker = asyncio.wait_for(_poll_marker(), timeout=timeout)
88-
listen_for_message = asyncio.wait_for(_listen_for_message(), timeout=timeout)
8987

9088
try:
91-
tasks = {
92-
asyncio.create_task(poll_marker, name=f'poll_marker_{event_name}_{self.transfer_id}'),
93-
asyncio.create_task(listen_for_message, name=f'listen_for_message_{event_name}_{self.transfer_id}')
94-
}
95-
_, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
96-
for task in pending:
97-
task.cancel()
98-
99-
except asyncio.TimeoutError:
89+
with anyio.fail_after(timeout):
90+
async with anyio.create_task_group() as tg:
91+
tg.start_soon(self._poll_marker, event_key)
92+
tg.start_soon(self._listen_for_message, pubsub, event_key)
93+
94+
except TimeoutError:
10095
self.error(f"Timeout waiting for event '{event_name}' after {timeout} seconds.")
101-
for task in tasks:
102-
task.cancel()
10396
raise
10497

10598
finally:
10699
await pubsub.unsubscribe(event_key)
107100
await pubsub.aclose()
108-
await asyncio.gather(*tasks, return_exceptions=True)
109101

110102
## Metadata operations ##
111103

lib/transfer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import asyncio
1+
import anyio
22
from starlette.responses import ClientDisconnect
33
from starlette.websockets import WebSocketDisconnect
44
from typing import AsyncIterator, Callable, Awaitable, Optional
@@ -113,8 +113,8 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
113113
self.error(f"△ Unexpected upload error: {e}")
114114
await self.store.put_in_queue(self.DEAD_FLAG)
115115

116-
except asyncio.TimeoutError as e:
117-
self.warning(f"△ Timeout during upload.")
116+
except TimeoutError as e:
117+
self.warning(f"△ Timeout during upload.", exc_info=True)
118118
await on_error("Timeout during upload.")
119119

120120
except TransferError as e:
@@ -125,7 +125,7 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
125125
await on_error(e)
126126

127127
finally:
128-
await asyncio.sleep(1.0)
128+
await anyio.sleep(1.0)
129129

130130
async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[None]]) -> AsyncIterator[bytes]:
131131
self.bytes_downloaded = 0
@@ -158,8 +158,9 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
158158

159159
async def cleanup(self):
160160
try:
161-
await asyncio.wait_for(self.store.cleanup(), timeout=30.0)
162-
except asyncio.TimeoutError:
161+
with anyio.fail_after(30.0):
162+
await self.store.cleanup()
163+
except TimeoutError:
163164
self.warning(f"- Cleanup timed out.")
164165
pass
165166

tests/helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import asyncio
1+
import anyio
22
from string import ascii_letters
33
from itertools import islice, repeat, chain
44
from typing import Tuple, Iterable, AsyncIterator
55
from annotated_types import T
6+
import anyio.lowlevel
67

78
from lib.metadata import FileMetadata
89

@@ -24,4 +25,4 @@ async def chunks(data: bytes, chunk_size: int = 1024) -> AsyncIterator[bytes]:
2425
"""Yield successive chunks of data."""
2526
for i in range(0, len(data), chunk_size):
2627
yield data[i:i + chunk_size]
27-
await asyncio.sleep(0)
28+
await anyio.lowlevel.checkpoint()

tests/test_endpoints.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import asyncio
1+
import anyio
22
import json
33
import pytest
44
import httpx
@@ -69,7 +69,7 @@ async def test_transfer_id_already_used(websocket_client: WebSocketTestClient):
6969

7070
# # Override the timeout for the test to make it fail quickly
7171
# async def mock_wait_for_client_connected(self):
72-
# await asyncio.sleep(1.0) # Short delay
72+
# await anyio.sleep(1.0) # Short delay
7373
# raise asyncio.TimeoutError("Mocked timeout")
7474

7575
# from lib.transfer import FileTransfer
@@ -95,32 +95,32 @@ async def test_receiver_disconnects(test_client: httpx.AsyncClient, websocket_cl
9595
async def sender():
9696
with pytest.raises(ConnectionClosedError, match="Transfer was interrupted by the receiver"):
9797
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
98-
await asyncio.sleep(0.1)
98+
await anyio.sleep(0.1)
9999

100100
await ws.send_json({
101101
'file_name': file_metadata.name,
102102
'file_size': file_metadata.size,
103103
'file_type': file_metadata.type
104104
})
105-
await asyncio.sleep(1.0) # Allow receiver to connect
105+
await anyio.sleep(1.0) # Allow receiver to connect
106106

107107
response = await ws.recv()
108-
await asyncio.sleep(0.1)
108+
await anyio.sleep(0.1)
109109
assert response == "Go for file chunks"
110110

111111
chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)]
112112
for chunk in chunks:
113113
await ws.send_bytes(chunk)
114-
await asyncio.sleep(0.1)
114+
await anyio.sleep(0.1)
115115

116-
await asyncio.sleep(2.0)
116+
await anyio.sleep(2.0)
117117

118118
async def receiver():
119-
await asyncio.sleep(1.0)
119+
await anyio.sleep(1.0)
120120
headers = {'Accept': '*/*'}
121121

122122
async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response:
123-
await asyncio.sleep(0.1)
123+
await anyio.sleep(0.1)
124124

125125
response.raise_for_status()
126126
i = 0
@@ -131,11 +131,11 @@ async def receiver():
131131
i += 1
132132
if i >= 5:
133133
raise ClientDisconnect("Simulated disconnect")
134-
await asyncio.sleep(0.025)
134+
await anyio.sleep(0.025)
135135

136-
t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
137-
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
138-
await asyncio.gather(t1, t2)
136+
async with anyio.create_task_group() as tg:
137+
tg.start_soon(sender)
138+
tg.start_soon(receiver)
139139

140140

141141
@pytest.mark.anyio
@@ -146,18 +146,18 @@ async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_clie
146146

147147
# Create a dummy transfer to get metadata
148148
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
149-
await asyncio.sleep(0.1)
149+
await anyio.sleep(0.1)
150150

151151
await ws.send_json({
152152
'file_name': file_metadata.name,
153153
'file_size': file_metadata.size,
154154
'file_type': file_metadata.type
155155
})
156-
await asyncio.sleep(1.0)
156+
await anyio.sleep(1.0)
157157

158158
headers = {'User-Agent': 'facebookexternalhit/1.1'}
159159
response = await test_client.get(f"/{uid}", headers=headers)
160-
await asyncio.sleep(0.1)
160+
await anyio.sleep(0.1)
161161

162162
assert response.status_code == 200
163163
assert "text/html" in response.headers['content-type']
@@ -172,18 +172,18 @@ async def test_browser_download_page(test_client: httpx.AsyncClient, websocket_c
172172
_, file_metadata = generate_test_file()
173173

174174
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
175-
await asyncio.sleep(0.1)
175+
await anyio.sleep(0.1)
176176

177177
await ws.send_json({
178178
'file_name': file_metadata.name,
179179
'file_size': file_metadata.size,
180180
'file_type': file_metadata.type
181181
})
182-
await asyncio.sleep(1.0)
182+
await anyio.sleep(1.0)
183183

184184
headers = {'User-Agent': 'Mozilla/5.0'}
185185
response = await test_client.get(f"/{uid}", headers=headers)
186-
await asyncio.sleep(0.1)
186+
await anyio.sleep(0.1)
187187

188188
assert response.status_code == 200
189189
assert "text/html" in response.headers['content-type']

tests/test_journeys.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import asyncio
1+
import anyio
22
import httpx
33
import json
44
import pytest
@@ -15,55 +15,55 @@ async def test_websocket_upload_http_download(test_client: httpx.AsyncClient, we
1515

1616
async def sender():
1717
async with websocket_client.websocket_connect(f"/send/{uid}") as ws:
18-
await asyncio.sleep(0.1)
18+
await anyio.sleep(0.1)
1919

2020
await ws.websocket.send(json.dumps({
2121
'file_name': file_metadata.name,
2222
'file_size': file_metadata.size,
2323
'file_type': file_metadata.type
2424
}))
25-
await asyncio.sleep(1.0)
25+
await anyio.sleep(1.0)
2626

2727
# Wait for receiver to connect
2828
response = await ws.websocket.recv()
29-
await asyncio.sleep(0.1)
29+
await anyio.sleep(0.1)
3030
assert response == "Go for file chunks"
3131

3232
# Send file
3333
chunk_size = 4096
3434
for i in range(0, len(file_content), chunk_size):
3535
await ws.websocket.send(file_content[i:i + chunk_size])
36-
await asyncio.sleep(0.025)
36+
await anyio.sleep(0.025)
3737

3838
await ws.websocket.send(b'') # End of file
39-
await asyncio.sleep(0.1)
39+
await anyio.sleep(0.1)
4040

4141
async def receiver():
42-
await asyncio.sleep(1.0)
42+
await anyio.sleep(1.0)
4343
headers = {'User-Agent': 'Mozilla/5.0', 'Accept': '*/*'}
4444

4545
async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response:
46-
await asyncio.sleep(0.1)
46+
await anyio.sleep(0.1)
4747

4848
response.raise_for_status()
4949
assert response.headers['content-length'] == str(file_metadata.size)
5050
assert f"filename={file_metadata.name}" in response.headers['content-disposition']
51-
await asyncio.sleep(0.1)
51+
await anyio.sleep(0.1)
5252

5353
downloaded_content = b''
5454
async for chunk in response.aiter_bytes(4096):
5555
if not chunk or len(downloaded_content) >= file_metadata.size:
5656
break
5757
downloaded_content += chunk
58-
await asyncio.sleep(0.025)
58+
await anyio.sleep(0.025)
5959

6060
assert len(downloaded_content) == file_metadata.size
6161
assert downloaded_content == file_content
62-
await asyncio.sleep(0.1)
62+
await anyio.sleep(0.1)
6363

64-
t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
65-
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
66-
await asyncio.gather(t1, t2, return_exceptions=True)
64+
async with anyio.create_task_group() as tg:
65+
tg.start_soon(sender)
66+
tg.start_soon(receiver)
6767

6868

6969
@pytest.mark.anyio
@@ -78,22 +78,22 @@ async def sender():
7878
'Content-Length': str(file_metadata.size)
7979
}
8080
async with test_client.stream("PUT", f"/{uid}/{file_metadata.name}", content=file_content, headers=headers) as response:
81-
await asyncio.sleep(1.0)
81+
await anyio.sleep(1.0)
8282

8383
response.raise_for_status()
8484
assert response.status_code == 200
85-
await asyncio.sleep(0.1)
85+
await anyio.sleep(0.1)
8686

8787
async def receiver():
88-
await asyncio.sleep(1.0)
88+
await anyio.sleep(1.0)
8989
response = await test_client.get(f"/{uid}?download=true")
90-
await asyncio.sleep(0.1)
90+
await anyio.sleep(0.1)
9191

9292
response.raise_for_status()
9393
assert response.content == file_content
9494
assert len(response.content) == file_metadata.size
95-
await asyncio.sleep(0.1)
95+
await anyio.sleep(0.1)
9696

97-
t1 = asyncio.create_task(asyncio.wait_for(sender(), timeout=15))
98-
t2 = asyncio.create_task(asyncio.wait_for(receiver(), timeout=15))
99-
await asyncio.gather(t1, t2, return_exceptions=True)
97+
async with anyio.create_task_group() as tg:
98+
tg.start_soon(sender)
99+
tg.start_soon(receiver)

views/http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import string
2-
import asyncio
2+
import anyio
33
from fastapi import Request, APIRouter
44
from fastapi.templating import Jinja2Templates
55
from starlette.background import BackgroundTask
@@ -55,7 +55,7 @@ async def http_upload(request: Request, uid: str, filename: str):
5555

5656
try:
5757
await transfer.wait_for_client_connected()
58-
except asyncio.TimeoutError:
58+
except TimeoutError:
5959
log.warning("△ Receiver did not connect in time.")
6060
raise HTTPException(status_code=408, detail="Client did not connect in time.")
6161

0 commit comments

Comments
 (0)