11import asyncio
22from starlette .responses import ClientDisconnect
33from starlette .websockets import WebSocketDisconnect
4- from typing import AsyncIterator , Callable , Awaitable , Optional
4+ from typing import AsyncIterator , Callable , Awaitable , Optional , Any
55
66from lib .store import Store
77from lib .metadata import FileMetadata
88from lib .logging import HasLogging , get_logger
9-
109logger = get_logger ('transfer' )
1110
1211
13- class FileTransferError (Exception ):
14- """Base class for file transfer errors."""
15- def __init__ (self , * args , ** kwargs ):
16- self .args = args
17- self .kwargs = kwargs
12+ class TransferError (Exception ):
13+ """Custom exception for transfer errors with optional propagation control."""
14+ def __init__ (self , * args , propagate : bool = False , ** extra : Any ) -> None :
1815 super ().__init__ (* args )
19- for name , value in kwargs .items ():
20- setattr (self , name , value )
21-
22- def __str__ (self ):
23- kwargs_str = ', ' .join (f"{ k } ={ v } " for k , v in self .kwargs .items ())
24- details = f" - { kwargs_str } " if kwargs_str else ''
25- return super ().__str__ () + details
16+ self .propagate = propagate
17+ self .extra = extra
2618
27- def __repr__ (self ):
28- return self .__class__ .__name__ + f"({ ', ' .join (map (repr , self .args ))} , { self .kwargs } )"
19+ @property
20+ def shutdown (self ) -> bool :
21+ """Indicates if the transfer should be shut down (usually the opposite of `propagate`)."""
22+ return self .extra .get ('shutdown' , not self .propagate )
2923
3024
3125class FileTransfer (metaclass = HasLogging , name_from = 'uid' ):
@@ -104,13 +98,13 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
10498 break
10599
106100 if await self .is_interrupted ():
107- raise FileTransferError ("Transfer was interrupted by the receiver." , propagate = False , shutdown = True )
101+ raise TransferError ("Transfer was interrupted by the receiver." , propagate = False )
108102
109103 await self .store .put_in_queue (chunk )
110104 self .bytes_uploaded += len (chunk )
111105
112106 if self .bytes_uploaded < self .file .size :
113- raise FileTransferError ("Received less data than expected." , propagate = True , shutdown = False )
107+ raise TransferError ("Received less data than expected." , propagate = True )
114108
115109 self .debug (f"△ End of upload, sending done marker." )
116110 await self .store .put_in_queue (self .DONE_FLAG )
@@ -123,11 +117,11 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
123117 self .warning (f"△ Timeout during upload." )
124118 await on_error ("Timeout during upload." )
125119
126- except FileTransferError as e :
120+ except TransferError as e :
127121 self .warning (f"△ Upload error: { e } " )
128122 if e .propagate :
129123 await self .store .put_in_queue (self .DEAD_FLAG )
130- if e . shutdown :
124+ else :
131125 await on_error (e )
132126
133127 finally :
@@ -141,10 +135,10 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
141135 chunk = await self .store .get_from_queue ()
142136
143137 if chunk == self .DEAD_FLAG :
144- raise FileTransferError ("Sender disconnected." )
138+ raise TransferError ("Sender disconnected." )
145139
146140 if chunk == self .DONE_FLAG and self .bytes_downloaded < self .file .size :
147- raise FileTransferError ("Received less data than expected." )
141+ raise TransferError ("Received less data than expected." )
148142
149143 elif chunk == self .DONE_FLAG :
150144 self .debug (f"▼ Done marker received, ending download." )
@@ -153,17 +147,13 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
153147 self .bytes_downloaded += len (chunk )
154148 yield chunk
155149
156- except (ClientDisconnect , WebSocketDisconnect ) as e :
157- self .error (f"▼ Unexpected download error: { e } " )
158- await on_error (e )
159-
160- except asyncio .TimeoutError as e :
161- self .warning (f"▼ Timeout during download." )
162- self .debug ("Debug info:" , exc_info = e , stack_info = True )
150+ except Exception as e :
151+ self .error (f"▼ Unexpected download error!" , exc_info = True )
152+ self .debug ("Debug info:" , stack_info = True )
163153 await on_error (e )
164154
165- except FileTransferError as e :
166- self .warning (f"▼ Download error: { e } " )
155+ except TransferError as e :
156+ self .warning (f"▼ Download error" )
167157 await on_error (e )
168158
169159 async def cleanup (self ):
@@ -179,5 +169,10 @@ async def finalize_download(self):
179169 self .warning ("▼ Client disconnected before download was complete." )
180170 await self .set_interrupted ()
181171
182- await asyncio .sleep (2.0 )
172+ await self .cleanup ()
173+ # self.debug("▼ Finalizing download...")
174+ if self .bytes_downloaded < self .file .size and not await self .is_interrupted ():
175+ self .warning ("▼ Client disconnected before download was complete." )
176+ await self .set_interrupted ()
177+
183178 await self .cleanup ()
0 commit comments