Skip to content

Commit 2d71e5c

Browse files
committed
More verbose GA tests
1 parent 927d3ae commit 2d71e5c

File tree

3 files changed

+29
-34
lines changed

3 files changed

+29
-34
lines changed

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ jobs:
3131
uses: supercharge/[email protected]
3232

3333
- name: Run tests
34-
run: pytest --disable-pytest-warnings -v tests/
34+
run: pytest -s -v --tb=short tests/

lib/transfer.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,25 @@
11
import asyncio
22
from starlette.responses import ClientDisconnect
33
from starlette.websockets import WebSocketDisconnect
4-
from typing import AsyncIterator, Callable, Awaitable, Optional
4+
from typing import AsyncIterator, Callable, Awaitable, Optional, Any
55

66
from lib.store import Store
77
from lib.metadata import FileMetadata
88
from lib.logging import HasLogging, get_logger
9-
109
logger = get_logger('transfer')
1110

1211

13-
class FileTransferError(Exception):
14-
"""Base class for file transfer errors."""
15-
def __init__(self, *args, **kwargs):
16-
self.args = args
17-
self.kwargs = kwargs
12+
class TransferError(Exception):
13+
"""Custom exception for transfer errors with optional propagation control."""
14+
def __init__(self, *args, propagate: bool = False, **extra: Any) -> None:
1815
super().__init__(*args)
19-
for name, value in kwargs.items():
20-
setattr(self, name, value)
21-
22-
def __str__(self):
23-
kwargs_str = ', '.join(f"{k}={v}" for k, v in self.kwargs.items())
24-
details = f" - {kwargs_str}" if kwargs_str else ''
25-
return super().__str__() + details
16+
self.propagate = propagate
17+
self.extra = extra
2618

27-
def __repr__(self):
28-
return self.__class__.__name__ + f"({', '.join(map(repr, self.args))}, {self.kwargs})"
19+
@property
20+
def shutdown(self) -> bool:
21+
"""Indicates if the transfer should be shut down (usually the opposite of `propagate`)."""
22+
return self.extra.get('shutdown', not self.propagate)
2923

3024

3125
class FileTransfer(metaclass=HasLogging, name_from='uid'):
@@ -104,13 +98,13 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
10498
break
10599

106100
if await self.is_interrupted():
107-
raise FileTransferError("Transfer was interrupted by the receiver.", propagate=False, shutdown=True)
101+
raise TransferError("Transfer was interrupted by the receiver.", propagate=False)
108102

109103
await self.store.put_in_queue(chunk)
110104
self.bytes_uploaded += len(chunk)
111105

112106
if self.bytes_uploaded < self.file.size:
113-
raise FileTransferError("Received less data than expected.", propagate=True, shutdown=False)
107+
raise TransferError("Received less data than expected.", propagate=True)
114108

115109
self.debug(f"△ End of upload, sending done marker.")
116110
await self.store.put_in_queue(self.DONE_FLAG)
@@ -123,11 +117,11 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
123117
self.warning(f"△ Timeout during upload.")
124118
await on_error("Timeout during upload.")
125119

126-
except FileTransferError as e:
120+
except TransferError as e:
127121
self.warning(f"△ Upload error: {e}")
128122
if e.propagate:
129123
await self.store.put_in_queue(self.DEAD_FLAG)
130-
if e.shutdown:
124+
else:
131125
await on_error(e)
132126

133127
finally:
@@ -141,10 +135,10 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
141135
chunk = await self.store.get_from_queue()
142136

143137
if chunk == self.DEAD_FLAG:
144-
raise FileTransferError("Sender disconnected.")
138+
raise TransferError("Sender disconnected.")
145139

146140
if chunk == self.DONE_FLAG and self.bytes_downloaded < self.file.size:
147-
raise FileTransferError("Received less data than expected.")
141+
raise TransferError("Received less data than expected.")
148142

149143
elif chunk == self.DONE_FLAG:
150144
self.debug(f"▼ Done marker received, ending download.")
@@ -153,17 +147,13 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
153147
self.bytes_downloaded += len(chunk)
154148
yield chunk
155149

156-
except (ClientDisconnect, WebSocketDisconnect) as e:
157-
self.error(f"▼ Unexpected download error: {e}")
158-
await on_error(e)
159-
160-
except asyncio.TimeoutError as e:
161-
self.warning(f"▼ Timeout during download.")
162-
self.debug("Debug info:", exc_info=e, stack_info=True)
150+
except Exception as e:
151+
self.error(f"▼ Unexpected download error!", exc_info=True)
152+
self.debug("Debug info:", stack_info=True)
163153
await on_error(e)
164154

165-
except FileTransferError as e:
166-
self.warning(f"▼ Download error: {e}")
155+
except TransferError as e:
156+
self.warning(f"▼ Download error")
167157
await on_error(e)
168158

169159
async def cleanup(self):
@@ -179,5 +169,10 @@ async def finalize_download(self):
179169
self.warning("▼ Client disconnected before download was complete.")
180170
await self.set_interrupted()
181171

182-
await asyncio.sleep(2.0)
172+
await self.cleanup()
173+
# self.debug("▼ Finalizing download...")
174+
if self.bytes_downloaded < self.file.size and not await self.is_interrupted():
175+
self.warning("▼ Client disconnected before download was complete.")
176+
await self.set_interrupted()
177+
183178
await self.cleanup()

views/websockets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ async def websocket_upload(websocket: WebSocket, uid: str):
2222
Then, the client must wait for the signal before sending file chunks.
2323
"""
2424
if any(char not in string.ascii_letters + string.digits + '-' for char in uid):
25-
log.debug(f"△ Invalid transfer ID: {uid}")
25+
log.debug(f"△ Invalid transfer ID.")
2626
await websocket.close(code=1008, reason="Invalid transfer ID")
2727
return
2828

0 commit comments

Comments
 (0)