11import asyncio
22from starlette .responses import ClientDisconnect
33from starlette .websockets import WebSocketDisconnect
4- from typing import AsyncIterator , Callable , Awaitable , Optional
4+ from typing import Any , AsyncIterator , Callable , Awaitable , Iterable , Optional , Mapping
55
66from lib .store import Store
77from lib .metadata import FileMetadata
8- from lib .logging import HasLogging , get_logger
9-
8+ from lib .logging import HasLogging , get_logger , italic , dim
109logger = get_logger ('transfer' )
1110
1211
13- class FileTransferError (Exception ):
14- """Base class for file transfer errors."""
15- def __init__ (self , * args , ** kwargs ):
16- self .args = args
17- self .kwargs = kwargs
12+ class TransferError (Exception ):
13+ """Custom exception for transfer errors with optional propagation control."""
14+ def __init__ (self , * args , propagate : bool = False , ** extra : Any ) -> None :
1815 super ().__init__ (* args )
19- for name , value in kwargs . items ():
20- setattr ( self , name , value )
16+ self . propagate = propagate
17+ self . extra = extra
2118
22- def __str__ ( self ):
23- kwargs_str = ', ' . join ( f" { k } = { v } " for k , v in self . kwargs . items ())
24- details = f" - { kwargs_str } " if kwargs_str else ''
25- return super (). __str__ () + details
19+ @ property
20+ def shutdown ( self ) -> bool :
21+ """Indicates if the transfer should be shut down (usually the opposite of `propagate`)."""
22+ return self . extra . get ( 'shutdown' , not self . propagate )
2623
27- def __repr__ (self ):
28- return self .__class__ .__name__ + f"({ ', ' .join (map (repr , self .args ))} , { self .kwargs } )"
24+ def __str__ (self ) -> str :
25+ propagate_str = italic ("propagated" ) if self .propagate else dim (italic ("not propagated" ))
26+ extra_str = ', ' .join (f"{ k } ={ v } " for k , v in self .extra .items ()) if self .extra else ''
27+ return f"TransferError: { ', ' .join (self .args )} ({ propagate_str } { ', ' + extra_str if extra_str else '' } )"
2928
3029
3130class FileTransfer (metaclass = HasLogging , name_from = 'uid' ):
@@ -104,13 +103,13 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
104103 break
105104
106105 if await self .is_interrupted ():
107- raise FileTransferError ("Transfer was interrupted by the receiver." , propagate = False , shutdown = True )
106+ raise TransferError ("Transfer was interrupted by the receiver." , propagate = False , shutdown = True )
108107
109108 await self .store .put_in_queue (chunk )
110109 self .bytes_uploaded += len (chunk )
111110
112111 if self .bytes_uploaded < self .file .size :
113- raise FileTransferError ("Received less data than expected." , propagate = True , shutdown = False )
112+ raise TransferError ("Received less data than expected." , propagate = True , shutdown = False )
114113
115114 self .debug (f"△ End of upload, sending done marker." )
116115 await self .store .put_in_queue (self .DONE_FLAG )
@@ -123,11 +122,11 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
123122 self .warning (f"△ Timeout during upload." )
124123 await on_error ("Timeout during upload." )
125124
126- except FileTransferError as e :
125+ except TransferError as e :
127126 self .warning (f"△ Upload error: { e } " )
128127 if e .propagate :
129128 await self .store .put_in_queue (self .DEAD_FLAG )
130- if e . shutdown :
129+ else :
131130 await on_error (e )
132131
133132 finally :
@@ -141,10 +140,10 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
141140 chunk = await self .store .get_from_queue ()
142141
143142 if chunk == self .DEAD_FLAG :
144- raise FileTransferError ("Sender disconnected." )
143+ raise TransferError ("Sender disconnected." )
145144
146145 if chunk == self .DONE_FLAG and self .bytes_downloaded < self .file .size :
147- raise FileTransferError ("Received less data than expected." )
146+ raise TransferError ("Received less data than expected." )
148147
149148 elif chunk == self .DONE_FLAG :
150149 self .debug (f"▼ Done marker received, ending download." )
@@ -153,17 +152,13 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
153152 self .bytes_downloaded += len (chunk )
154153 yield chunk
155154
156- except (ClientDisconnect , WebSocketDisconnect ) as e :
157- self .error (f"▼ Unexpected download error: { e } " )
158- await on_error (e )
159-
160- except asyncio .TimeoutError as e :
161- self .warning (f"▼ Timeout during download." )
162- self .debug ("Debug info:" , exc_info = e , stack_info = True )
155+ except Exception as e :
156+ self .error (f"▼ Unexpected download error!" , exc_info = True )
157+ self .debug ("Debug info:" , stack_info = True )
163158 await on_error (e )
164159
165- except FileTransferError as e :
166- self .warning (f"▼ Download error: { e } " )
160+ except TransferError as e :
161+ self .warning (f"▼ Download error" )
167162 await on_error (e )
168163
169164 async def cleanup (self ):
@@ -179,5 +174,10 @@ async def finalize_download(self):
179174 self .warning ("▼ Client disconnected before download was complete." )
180175 await self .set_interrupted ()
181176
182- await asyncio .sleep (2.0 )
177+ await self .cleanup ()
178+ # self.debug("▼ Finalizing download...")
179+ if self .bytes_downloaded < self .file .size and not await self .is_interrupted ():
180+ self .warning ("▼ Client disconnected before download was complete." )
181+ await self .set_interrupted ()
182+
183183 await self .cleanup ()
0 commit comments