Skip to content

Commit 89c556b

Browse files
committed
Big refactoring
1 parent 8e72a6f commit 89c556b

File tree

9 files changed

+325
-161
lines changed

9 files changed

+325
-161
lines changed

app.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import os
2-
import redis
32
import logging
4-
import redis.asyncio
53
import sentry_sdk
4+
import redis.asyncio
65
from fastapi import FastAPI
76
from fastapi.responses import FileResponse
87
from fastapi.staticfiles import StaticFiles
@@ -22,13 +21,13 @@
2221
# Redis
2322
redis_client = redis.asyncio.from_url(os.getenv("REDIS_URL", "redis://localhost:6379"))
2423

24+
# FastAPI
2525
@asynccontextmanager
2626
async def lifespan(app: FastAPI):
2727
setup_logging()
2828
yield
2929
await redis_client.close()
3030

31-
# FastAPI
3231
app = FastAPI(
3332
debug=True,
3433
title="Transit.sh",

lib/callbacks.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from fastapi import HTTPException
2+
from starlette.requests import Request
3+
from starlette.websockets import WebSocket, WebSocketState
4+
from typing import Awaitable, Callable
5+
6+
7+
def send_error_and_close(websocket: WebSocket) -> Callable[[Exception | str], Awaitable[None]]:
8+
"""Callback to send an error message and close the WebSocket connection."""
9+
10+
async def _send_error_and_close(error: Exception | str) -> None:
11+
message = str(error) if isinstance(error, Exception) else error
12+
if websocket.client_state == WebSocketState.CONNECTED:
13+
await websocket.send_text(f"Error: {message}")
14+
await websocket.close(code=1011, reason=message)
15+
16+
return _send_error_and_close
17+
18+
19+
def raise_http_exception(request: Request) -> Callable[[Exception | str], Awaitable[None]]:
20+
"""Callback to raise an HTTPException with a specific status code."""
21+
22+
async def _raise_http_exception(error: Exception | str) -> None:
23+
message = str(error) if isinstance(error, Exception) else error
24+
code = error.status_code if isinstance(error, HTTPException) else 400
25+
if not await request.is_disconnected():
26+
raise HTTPException(status_code=code, detail=message)
27+
28+
return _raise_http_exception

lib/metadata.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import json
2+
from dataclasses import dataclass, asdict
3+
from starlette.datastructures import Headers
4+
from typing import Optional, Self
5+
6+
7+
@dataclass(frozen=True)
8+
class FileMetadata:
9+
size: int
10+
name: str
11+
content_type: Optional[str] = None
12+
13+
def to_json(self) -> str:
14+
return json.dumps(asdict(self), skipkeys=True)
15+
16+
@classmethod
17+
def from_json(cls, data: str) -> Self:
18+
return cls(**json.loads(data))
19+
20+
@classmethod
21+
def get_from_http_headers(cls, headers: Headers, filename: str) -> Self:
22+
return cls(
23+
name=cls.escape_filename(filename),
24+
size=cls.process_length(headers.get('content-length', '0')),
25+
content_type=headers.get('content-type', '')
26+
)
27+
28+
@classmethod
29+
def get_from_json(cls, header: dict) -> Self:
30+
return cls(
31+
name=cls.escape_filename(header['file_name']),
32+
size=cls.process_length(header['file_size']),
33+
content_type=header['file_type']
34+
)
35+
36+
@staticmethod
37+
def escape_filename(filename: str) -> str:
38+
"""Escape special characters in the filename."""
39+
return str(filename).encode('latin-1', 'ignore').decode('utf-8', 'ignore')
40+
41+
@staticmethod
42+
def process_length(length: str | int) -> int:
43+
"""Convert size string to bytes."""
44+
try:
45+
size = int(str(length).strip().replace(' ', ''))
46+
except ValueError:
47+
raise ValueError(f"Invalid size format: {length}")
48+
if size <= 0:
49+
raise ValueError("File size has to be positive.")
50+
return size
51+
52+
def __str__(self):
53+
return f"{self.name} ({self.size/(1024**2):.1f} MiB - {self.content_type})"
54+
55+
def __repr__(self):
56+
return f"FileMetadata(name={self.name!r}, size={self.size/(1024**2):.1f}, content_type={self.content_type!r})"

lib/store.py

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class Store:
1212
Handles data queuing and event signaling for transfer coordination.
1313
"""
1414

15-
_redis: Optional[redis.Redis] = None
15+
redis_client: None | redis.Redis = None
1616

1717
def __init__(self, transfer_id: str):
1818
self.transfer_id = transfer_id
@@ -22,58 +22,32 @@ def __init__(self, transfer_id: str):
2222
self._k_queue = self.key('queue')
2323
self._k_meta = self.key('metadata')
2424
self._k_cleanup = f'cleanup:{transfer_id}'
25-
self._cleanup = False
2625

2726
@classmethod
2827
def get_redis(cls) -> redis.Redis:
2928
"""Get the Redis client instance."""
30-
if cls._redis is None:
29+
if cls.redis_client is None:
3130
from app import redis_client
32-
cls._redis = redis_client
33-
return cls._redis
31+
cls.redis_client = redis_client
32+
return cls.redis_client
3433

35-
def key(self, name: str):
36-
"""Get the Redis key for the provided name with this transfer."""
34+
def key(self, name: str) -> str:
35+
"""Get the Redis key for this transfer with the provided name."""
3736
return f'transfer:{self.transfer_id}:{name}'
3837

39-
async def cleanup_started(self) -> bool:
40-
"""Check if cleanup has been initiated for this transfer."""
41-
challenge = random.randbytes(8)
42-
await self.redis.set(self._k_cleanup, challenge, ex=60, nx=True)
43-
if await self.redis.get(self._k_cleanup) == challenge:
44-
return False
45-
return True
46-
47-
async def cleanup(self) -> int:
48-
"""Remove all keys related to this transfer."""
49-
if await self.cleanup_started():
50-
return 0
51-
52-
pattern = self.key('*')
53-
keys_to_delete = set()
54-
55-
cursor = 0
56-
while True:
57-
cursor, keys = await self.redis.scan(cursor, match=pattern)
58-
keys_to_delete |= set(keys)
59-
if cursor == 0:
60-
break
61-
62-
if keys_to_delete:
63-
return await self.redis.delete(*keys_to_delete)
64-
6538
## Queue operations ##
6639

6740
async def _wait_for_queue_space(self, maxsize: int) -> None:
6841
while await self.redis.llen(self._k_queue) >= maxsize:
6942
await asyncio.sleep(0.5)
7043

71-
async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 30.0) -> None:
44+
async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 10.0) -> None:
7245
"""Add data to the transfer queue with backpressure control."""
73-
await asyncio.wait_for(self._wait_for_queue_space(maxsize), timeout=timeout)
46+
async with asyncio.timeout(timeout):
47+
await self._wait_for_queue_space(maxsize)
7448
await self.redis.lpush(self._k_queue, data)
7549

76-
async def get_from_queue(self, timeout: float = 30.0) -> bytes:
50+
async def get_from_queue(self, timeout: float = 10.0) -> bytes:
7751
"""Get data from the transfer queue with timeout."""
7852
result = await self.redis.brpop([self._k_queue], timeout=timeout)
7953
if not result:
@@ -87,15 +61,15 @@ async def get_from_queue(self, timeout: float = 30.0) -> bytes:
8761
async def set_event(self, event_name: str, expiry: float = 300.0) -> None:
8862
"""Set an event flag for this transfer."""
8963
event_key = self.key(event_name)
90-
event_marker_key = self.key(f"{event_name}:marker")
64+
event_marker_key = f'{event_key}:marker'
9165

9266
await self.redis.set(event_marker_key, '1', ex=int(expiry))
9367
await self.redis.publish(event_key, '1')
9468

9569
async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None:
9670
"""Wait for an event to be set for this transfer."""
9771
event_key = self.key(event_name)
98-
event_marker_key = self.key(f"{event_name}:marker")
72+
event_marker_key = f'{event_key}:marker'
9973
pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
10074
await pubsub.subscribe(event_key)
10175

@@ -135,10 +109,64 @@ async def _listen_for_message():
135109

136110
## Metadata operations ##
137111

138-
async def set_metadata(self, metadata: str, expiry: float = 3600.0) -> None:
112+
async def set_metadata(self, metadata: str) -> None:
139113
"""Store transfer metadata."""
140-
await self.redis.set(self._k_meta, metadata, ex=int(expiry), nx=True)
114+
if int (await self.redis.exists(self._k_meta)) > 0:
115+
raise KeyError(f"Metadata for transfer '{self.transfer_id}' already exists.")
116+
await self.redis.set(self._k_meta, metadata, nx=True)
141117

142-
async def get_metadata(self) -> Optional[str]:
118+
async def get_metadata(self) -> str | None:
143119
"""Retrieve transfer metadata."""
144120
return await self.redis.get(self._k_meta)
121+
122+
## Transfer state operations ##
123+
124+
async def set_completed(self) -> None:
125+
"""Mark the transfer as completed."""
126+
await self.redis.set(f'completed:{self.transfer_id}', '1', ex=300, nx=True)
127+
128+
async def is_completed(self) -> bool:
129+
"""Check if the transfer is marked as completed."""
130+
return await self.redis.exists(f'completed:{self.transfer_id}') > 0
131+
132+
async def set_interrupted(self) -> None:
133+
"""Mark the transfer as interrupted."""
134+
await self.redis.set(f'interrupt:{self.transfer_id}', '1', ex=300, nx=True)
135+
await self.redis.ltrim(self._k_queue, 0, 0)
136+
137+
async def is_interrupted(self) -> bool:
138+
"""Check if the transfer was interrupted."""
139+
return await self.redis.exists(f'interrupt:{self.transfer_id}') > 0
140+
141+
## Cleanup operations ##
142+
143+
async def cleanup_started(self) -> bool:
144+
"""
145+
Check if cleanup has already been initiated for this transfer.
146+
This uses a set/get pattern with challenge to avoid race conditions.
147+
"""
148+
challenge = random.randbytes(8)
149+
await self.redis.set(self._k_cleanup, challenge, ex=60, nx=True)
150+
if await self.redis.get(self._k_cleanup) == challenge:
151+
return False
152+
return True
153+
154+
async def cleanup(self) -> int:
155+
"""Remove all keys related to this transfer."""
156+
if await self.cleanup_started():
157+
return 0
158+
159+
pattern = self.key('*')
160+
keys_to_delete = set()
161+
162+
cursor = 0
163+
while True:
164+
cursor, keys = await self.redis.scan(cursor, match=pattern)
165+
keys_to_delete |= set(keys)
166+
if cursor == 0:
167+
break
168+
169+
if keys_to_delete:
170+
self.log.debug(f"- Cleaning up {len(keys_to_delete)} keys")
171+
return await self.redis.delete(*keys_to_delete)
172+
return 0

0 commit comments

Comments
 (0)