Skip to content

Commit 836fd2a

Browse files
authored
Merge pull request #7 from codeSamuraii/new-test-client
New "live" integration test client + misc
2 parents 85104e9 + 2d71e5c commit 836fd2a

File tree

13 files changed

+388
-136
lines changed

13 files changed

+388
-136
lines changed

.github/workflows/unit-tests.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,8 @@ jobs:
2727
pip install -r requirements.txt
2828
pip install pytest fakeredis
2929
30+
- name: Redis Server in GitHub Actions
31+
uses: supercharge/[email protected]
32+
3033
- name: Run tests
31-
run: pytest --disable-pytest-warnings -v tests/
34+
run: pytest -s -v --tb=short tests/

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ Dockerfile
1414
fly.toml
1515
.pytest_cache/
1616
*.instructions.md
17+
CLAUDE.md
18+
poetry.lock
19+
pyproject.toml

lib/callbacks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@ async def _send_error_and_close(error: Exception | str) -> None:
1616
return _send_error_and_close
1717

1818

19+
class StreamTerminated(Exception):
20+
"""Raised to terminate a stream when an error occurs after the response has started."""
21+
pass
22+
1923
def raise_http_exception(request: Request) -> Callable[[Exception | str], Awaitable[None]]:
2024
"""Callback to raise an HTTPException with a specific status code."""
2125

2226
async def _raise_http_exception(error: Exception | str) -> None:
2327
message = str(error) if isinstance(error, Exception) else error
2428
code = error.status_code if isinstance(error, HTTPException) else 400
25-
if not await request.is_disconnected():
26-
raise HTTPException(status_code=code, detail=message)
29+
raise StreamTerminated(f"{code}: {message}") from error
2730

2831
return _raise_http_exception

lib/logging.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22

33

4+
# ---- FORMATTING ----
5+
46
class ColoredFormatter(logging.Formatter):
57
"""Custom formatter with colored output and specific formatting for transfers."""
68

@@ -59,12 +61,13 @@ def format(self, record: logging.LogRecord) -> str:
5961

6062
return result + exc_text
6163

62-
6364
class HealthCheckFilter(logging.Filter):
6465
def filter(self, record):
6566
return '"GET /health HTTP/1.1" 200' not in record.getMessage()
6667

6768

69+
# ---- LOGGING SETUP ----
70+
6871
def setup_logging() -> logging.Logger:
6972
"""Configure all loggers to use our custom ColoredFormatter."""
7073
formatter = ColoredFormatter(
@@ -96,7 +99,6 @@ def setup_logging() -> logging.Logger:
9699

97100
return root_logger
98101

99-
100102
def get_logger(logger_name: str) -> logging.Logger:
101103
console_handler = logging.StreamHandler()
102104
console_handler.setLevel(logging.DEBUG)
@@ -114,3 +116,38 @@ def get_logger(logger_name: str) -> logging.Logger:
114116
logger.propagate = False
115117

116118
return logger
119+
120+
121+
# ---- PATCHING ----
122+
123+
class HasLogging(type):
124+
"""Metaclass that automatically adds logging methods and a logger property."""
125+
126+
def __new__(mcs, name, bases, namespace, **kwargs):
127+
name_from = kwargs.get('name_from', 'name')
128+
129+
@property
130+
def logger(self):
131+
if not hasattr(self, '_logger'):
132+
class_name = self.__class__.__name__
133+
fallback_name = class_name + str(id(self))[-4:]
134+
logger_name = getattr(self, name_from, fallback_name)
135+
self._logger = get_logger(logger_name)
136+
if not hasattr(self, name_from):
137+
self._logger.warning(
138+
f"Object {class_name} does not have attribute '{name_from}', "
139+
f"using default name: {logger_name}"
140+
)
141+
return self._logger
142+
143+
namespace['logger'] = logger
144+
145+
def make_log_method(level):
146+
def log_method(self, msg, *args, **kwargs):
147+
getattr(self.logger, level)(msg, *args, **kwargs)
148+
return log_method
149+
150+
for level in {'debug', 'info', 'warning', 'error', 'exception', 'critical'}:
151+
namespace[level] = make_log_method(level)
152+
153+
return super().__new__(mcs, name, bases, namespace)

lib/store.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import redis.asyncio as redis
44
from typing import Optional, Annotated
55

6-
from lib.logging import get_logger
6+
from lib.logging import HasLogging, get_logger
77

88

9-
class Store:
9+
class Store(metaclass=HasLogging, name_from='transfer_id'):
1010
"""
1111
Redis-based store for file transfer queues and events.
1212
Handles data queuing and event signaling for transfer coordination.
@@ -17,7 +17,6 @@ class Store:
1717
def __init__(self, transfer_id: str):
1818
self.transfer_id = transfer_id
1919
self.redis = self.get_redis()
20-
self.log = get_logger(transfer_id)
2120

2221
self._k_queue = self.key('queue')
2322
self._k_meta = self.key('metadata')
@@ -42,13 +41,13 @@ async def _wait_for_queue_space(self, maxsize: int) -> None:
4241
while await self.redis.llen(self._k_queue) >= maxsize:
4342
await asyncio.sleep(0.5)
4443

45-
async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 10.0) -> None:
44+
async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 20.0) -> None:
4645
"""Add data to the transfer queue with backpressure control."""
4746
async with asyncio.timeout(timeout):
4847
await self._wait_for_queue_space(maxsize)
4948
await self.redis.lpush(self._k_queue, data)
5049

51-
async def get_from_queue(self, timeout: float = 10.0) -> bytes:
50+
async def get_from_queue(self, timeout: float = 20.0) -> bytes:
5251
"""Get data from the transfer queue with timeout."""
5352
result = await self.redis.brpop([self._k_queue], timeout=timeout)
5453
if not result:
@@ -77,12 +76,12 @@ async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None:
7776
async def _poll_marker():
7877
while not await self.redis.exists(event_marker_key):
7978
await asyncio.sleep(1)
80-
self.log.debug(f">> POLL: Event '{event_name}' fired.")
79+
self.debug(f">> POLL: Event '{event_name}' fired.")
8180

8281
async def _listen_for_message():
8382
async for message in pubsub.listen():
8483
if message and message['type'] == 'message':
85-
self.log.debug(f">> SUB : Received message for event '{event_name}'.")
84+
self.debug(f">> SUB : Received message for event '{event_name}'.")
8685
return
8786

8887
poll_marker = asyncio.wait_for(_poll_marker(), timeout=timeout)
@@ -98,7 +97,7 @@ async def _listen_for_message():
9897
task.cancel()
9998

10099
except asyncio.TimeoutError:
101-
self.log.error(f"Timeout waiting for event '{event_name}' after {timeout} seconds.")
100+
self.error(f"Timeout waiting for event '{event_name}' after {timeout} seconds.")
102101
for task in tasks:
103102
task.cancel()
104103
raise
@@ -112,9 +111,12 @@ async def _listen_for_message():
112111

113112
async def set_metadata(self, metadata: str) -> None:
114113
"""Store transfer metadata."""
115-
if int (await self.redis.exists(self._k_meta)) > 0:
116-
raise KeyError(f"Metadata for transfer '{self.transfer_id}' already exists.")
117-
await self.redis.set(self._k_meta, metadata, nx=True)
114+
challenge = random.randbytes(8)
115+
await self.redis.set(self._k_meta, challenge, nx=True)
116+
if await self.redis.get(self._k_meta) == challenge:
117+
await self.redis.set(self._k_meta, metadata, ex=300)
118+
else:
119+
raise KeyError("Metadata already set for this transfer.")
118120

119121
async def get_metadata(self) -> str | None:
120122
"""Retrieve transfer metadata."""
@@ -179,6 +181,6 @@ async def cleanup(self) -> int:
179181
break
180182

181183
if keys_to_delete:
182-
self.log.debug(f"- Cleaning up {len(keys_to_delete)} keys")
184+
self.debug(f"- Cleaning up {len(keys_to_delete)} keys")
183185
return await self.redis.delete(*keys_to_delete)
184186
return 0

lib/transfer.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,29 @@
11
import asyncio
22
from starlette.responses import ClientDisconnect
33
from starlette.websockets import WebSocketDisconnect
4-
from typing import AsyncIterator, Callable, Awaitable
4+
from typing import AsyncIterator, Callable, Awaitable, Optional, Any
55

66
from lib.store import Store
7-
from lib.logging import get_logger
87
from lib.metadata import FileMetadata
8+
from lib.logging import HasLogging, get_logger
9+
logger = get_logger('transfer')
910

1011

11-
class FileTransfer:
12+
class TransferError(Exception):
13+
"""Custom exception for transfer errors with optional propagation control."""
14+
def __init__(self, *args, propagate: bool = False, **extra: Any) -> None:
15+
super().__init__(*args)
16+
self.propagate = propagate
17+
self.extra = extra
18+
19+
@property
20+
def shutdown(self) -> bool:
21+
"""Indicates if the transfer should be shut down (usually the opposite of `propagate`)."""
22+
return self.extra.get('shutdown', not self.propagate)
23+
24+
25+
class FileTransfer(metaclass=HasLogging, name_from='uid'):
26+
"""Handles file transfers, including metadata queries and data streaming."""
1227

1328
DONE_FLAG = b'\x00\xFF'
1429
DEAD_FLAG = b'\xDE\xAD'
@@ -20,9 +35,6 @@ def __init__(self, uid: str, file: FileMetadata):
2035
self.bytes_uploaded = 0
2136
self.bytes_downloaded = 0
2237

23-
log = get_logger(self.uid)
24-
self.debug, self.info, self.warning, self.error = log.debug, log.info, log.warning, log.error
25-
2638
@classmethod
2739
async def create(cls, uid: str, file: FileMetadata):
2840
transfer = cls(uid, file)
@@ -86,27 +98,33 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
8698
break
8799

88100
if await self.is_interrupted():
89-
raise ClientDisconnect("Transfer was interrupted by the receiver.")
101+
raise TransferError("Transfer was interrupted by the receiver.", propagate=False)
90102

91103
await self.store.put_in_queue(chunk)
92104
self.bytes_uploaded += len(chunk)
93105

94106
if self.bytes_uploaded < self.file.size:
95-
raise ClientDisconnect("Received less data than expected.")
107+
raise TransferError("Received less data than expected.", propagate=True)
96108

97109
self.debug(f"△ End of upload, sending done marker.")
98110
await self.store.put_in_queue(self.DONE_FLAG)
99111

100112
except (ClientDisconnect, WebSocketDisconnect) as e:
101-
self.warning(f"△ Upload error: {str(e)}")
113+
self.error(f"△ Unexpected upload error: {e}")
102114
await self.store.put_in_queue(self.DEAD_FLAG)
103-
await on_error(e)
104115

105116
except asyncio.TimeoutError as e:
106117
self.warning(f"△ Timeout during upload.")
107118
await on_error("Timeout during upload.")
108119

109-
else:
120+
except TransferError as e:
121+
self.warning(f"△ Upload error: {e}")
122+
if e.propagate:
123+
await self.store.put_in_queue(self.DEAD_FLAG)
124+
else:
125+
await on_error(e)
126+
127+
finally:
110128
await asyncio.sleep(1.0)
111129

112130
async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[None]]) -> AsyncIterator[bytes]:
@@ -117,10 +135,10 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
117135
chunk = await self.store.get_from_queue()
118136

119137
if chunk == self.DEAD_FLAG:
120-
raise ClientDisconnect("Sender disconnected.")
138+
raise TransferError("Sender disconnected.")
121139

122140
if chunk == self.DONE_FLAG and self.bytes_downloaded < self.file.size:
123-
raise ClientDisconnect("Received less data than expected.")
141+
raise TransferError("Received less data than expected.")
124142

125143
elif chunk == self.DONE_FLAG:
126144
self.debug(f"▼ Done marker received, ending download.")
@@ -129,16 +147,14 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
129147
self.bytes_downloaded += len(chunk)
130148
yield chunk
131149

132-
except (ClientDisconnect, WebSocketDisconnect) as e:
133-
self.warning(f"▼ Download error: {e}")
134-
await self.set_interrupted()
135-
136-
except asyncio.TimeoutError:
137-
self.warning(f"▼ Timeout during download.")
138-
await on_error("Timeout during download.")
150+
except Exception as e:
151+
self.error(f"▼ Unexpected download error!", exc_info=True)
152+
self.debug("Debug info:", stack_info=True)
153+
await on_error(e)
139154

140-
else:
141-
await asyncio.sleep(1.0)
155+
except TransferError as e:
156+
self.warning(f"▼ Download error")
157+
await on_error(e)
142158

143159
async def cleanup(self):
144160
try:
@@ -148,10 +164,15 @@ async def cleanup(self):
148164
pass
149165

150166
async def finalize_download(self):
151-
self.debug("▼ Finalizing download...")
167+
# self.debug("▼ Finalizing download...")
168+
if self.bytes_downloaded < self.file.size and not await self.is_interrupted():
169+
self.warning("▼ Client disconnected before download was complete.")
170+
await self.set_interrupted()
171+
172+
await self.cleanup()
173+
# self.debug("▼ Finalizing download...")
152174
if self.bytes_downloaded < self.file.size and not await self.is_interrupted():
153175
self.warning("▼ Client disconnected before download was complete.")
154176
await self.set_interrupted()
155177

156-
await asyncio.sleep(4.0)
157178
await self.cleanup()

static/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ <h3>Using cURL</h3>
6969
<p><small>The <code>-JLO</code> flags downloads the file with its original name and follows redirects.</small></p>
7070
<div class="code-block">
7171
<code><span class="code-comment"># Example</span>
72-
curl -T <span class="code-variable">/music/song.mp3</span> https://transit.sh/<span class="code-string">music-for-dad</span>/
72+
curl -T <span class="code-variable">/music/song.mp3</span> https://transit.sh/<span class="code-string">music-for-dad</span>/ --expect100-timeout 300
7373
curl -JLO https://transit.sh/<span class="code-string">music-for-dad</span>/</code>
7474
</div>
7575
</div>

0 commit comments

Comments
 (0)