Skip to content

Commit 6ffa88b

Browse files
committed
More verbose GA tests
1 parent 9ed5bff commit 6ffa88b

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

lib/transfer.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
11
import asyncio
22
from starlette.responses import ClientDisconnect
33
from starlette.websockets import WebSocketDisconnect
4-
from typing import AsyncIterator, Callable, Awaitable, Optional
4+
from typing import Any, AsyncIterator, Callable, Awaitable, Iterable, Optional, Mapping
55

66
from lib.store import Store
77
from lib.metadata import FileMetadata
8-
from lib.logging import HasLogging, get_logger
9-
8+
from lib.logging import HasLogging, get_logger, italic, dim
109
logger = get_logger('transfer')
1110

1211

13-
class FileTransferError(Exception):
14-
"""Base class for file transfer errors."""
15-
def __init__(self, *args, **kwargs):
16-
self.args = args
17-
self.kwargs = kwargs
12+
class TransferError(Exception):
13+
"""Custom exception for transfer errors with optional propagation control."""
14+
def __init__(self, *args, propagate: bool = False, **extra: Any) -> None:
1815
super().__init__(*args)
19-
for name, value in kwargs.items():
20-
setattr(self, name, value)
16+
self.propagate = propagate
17+
self.extra = extra
2118

22-
def __str__(self):
23-
kwargs_str = ', '.join(f"{k}={v}" for k, v in self.kwargs.items())
24-
details = f" - {kwargs_str}" if kwargs_str else ''
25-
return super().__str__() + details
19+
@property
20+
def shutdown(self) -> bool:
21+
"""Indicates if the transfer should be shut down (usually the opposite of `propagate`)."""
22+
return self.extra.get('shutdown', not self.propagate)
2623

27-
def __repr__(self):
28-
return self.__class__.__name__ + f"({', '.join(map(repr, self.args))}, {self.kwargs})"
24+
def __str__(self) -> str:
25+
propagate_str = italic("propagated") if self.propagate else dim(italic("not propagated"))
26+
extra_str = ', '.join(f"{k}={v}" for k, v in self.extra.items()) if self.extra else ''
27+
return f"TransferError: {', '.join(self.args)} ({propagate_str}{', ' + extra_str if extra_str else ''})"
2928

3029

3130
class FileTransfer(metaclass=HasLogging, name_from='uid'):
@@ -104,13 +103,13 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
104103
break
105104

106105
if await self.is_interrupted():
107-
raise FileTransferError("Transfer was interrupted by the receiver.", propagate=False, shutdown=True)
106+
raise TransferError("Transfer was interrupted by the receiver.", propagate=False, shutdown=True)
108107

109108
await self.store.put_in_queue(chunk)
110109
self.bytes_uploaded += len(chunk)
111110

112111
if self.bytes_uploaded < self.file.size:
113-
raise FileTransferError("Received less data than expected.", propagate=True, shutdown=False)
112+
raise TransferError("Received less data than expected.", propagate=True, shutdown=False)
114113

115114
self.debug(f"△ End of upload, sending done marker.")
116115
await self.store.put_in_queue(self.DONE_FLAG)
@@ -123,11 +122,11 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
123122
self.warning(f"△ Timeout during upload.")
124123
await on_error("Timeout during upload.")
125124

126-
except FileTransferError as e:
125+
except TransferError as e:
127126
self.warning(f"△ Upload error: {e}")
128127
if e.propagate:
129128
await self.store.put_in_queue(self.DEAD_FLAG)
130-
if e.shutdown:
129+
else:
131130
await on_error(e)
132131

133132
finally:
@@ -141,10 +140,10 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
141140
chunk = await self.store.get_from_queue()
142141

143142
if chunk == self.DEAD_FLAG:
144-
raise FileTransferError("Sender disconnected.")
143+
raise TransferError("Sender disconnected.")
145144

146145
if chunk == self.DONE_FLAG and self.bytes_downloaded < self.file.size:
147-
raise FileTransferError("Received less data than expected.")
146+
raise TransferError("Received less data than expected.")
148147

149148
elif chunk == self.DONE_FLAG:
150149
self.debug(f"▼ Done marker received, ending download.")
@@ -153,17 +152,13 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
153152
self.bytes_downloaded += len(chunk)
154153
yield chunk
155154

156-
except (ClientDisconnect, WebSocketDisconnect) as e:
157-
self.error(f"▼ Unexpected download error: {e}")
158-
await on_error(e)
159-
160-
except asyncio.TimeoutError as e:
161-
self.warning(f"▼ Timeout during download.")
162-
self.debug("Debug info:", exc_info=e, stack_info=True)
155+
except Exception as e:
156+
self.error(f"▼ Unexpected download error!", exc_info=True)
157+
self.debug("Debug info:", stack_info=True)
163158
await on_error(e)
164159

165-
except FileTransferError as e:
166-
self.warning(f"▼ Download error: {e}")
160+
except TransferError as e:
161+
self.warning(f"▼ Download error")
167162
await on_error(e)
168163

169164
async def cleanup(self):
@@ -179,5 +174,10 @@ async def finalize_download(self):
179174
self.warning("▼ Client disconnected before download was complete.")
180175
await self.set_interrupted()
181176

182-
await asyncio.sleep(2.0)
177+
await self.cleanup()
178+
# self.debug("▼ Finalizing download...")
179+
if self.bytes_downloaded < self.file.size and not await self.is_interrupted():
180+
self.warning("▼ Client disconnected before download was complete.")
181+
await self.set_interrupted()
182+
183183
await self.cleanup()

0 commit comments

Comments
 (0)