11import asyncio
22from starlette .responses import ClientDisconnect
33from starlette .websockets import WebSocketDisconnect
4- from typing import AsyncIterator , Callable , Awaitable
4+ from typing import AsyncIterator , Callable , Awaitable , Optional , Any
55
66from lib .store import Store
7- from lib .logging import get_logger
87from lib .metadata import FileMetadata
8+ from lib .logging import HasLogging , get_logger
9+ logger = get_logger ('transfer' )
910
1011
11- class FileTransfer :
12+ class TransferError (Exception ):
13+ """Custom exception for transfer errors with optional propagation control."""
14+ def __init__ (self , * args , propagate : bool = False , ** extra : Any ) -> None :
15+ super ().__init__ (* args )
16+ self .propagate = propagate
17+ self .extra = extra
18+
19+ @property
20+ def shutdown (self ) -> bool :
21+ """Indicates if the transfer should be shut down (usually the opposite of `propagate`)."""
22+ return self .extra .get ('shutdown' , not self .propagate )
23+
24+
25+ class FileTransfer (metaclass = HasLogging , name_from = 'uid' ):
26+ """Handles file transfers, including metadata queries and data streaming."""
1227
1328 DONE_FLAG = b'\x00 \xFF '
1429 DEAD_FLAG = b'\xDE \xAD '
@@ -20,9 +35,6 @@ def __init__(self, uid: str, file: FileMetadata):
2035 self .bytes_uploaded = 0
2136 self .bytes_downloaded = 0
2237
23- log = get_logger (self .uid )
24- self .debug , self .info , self .warning , self .error = log .debug , log .info , log .warning , log .error
25-
2638 @classmethod
2739 async def create (cls , uid : str , file : FileMetadata ):
2840 transfer = cls (uid , file )
@@ -86,27 +98,33 @@ async def collect_upload(self, stream: AsyncIterator[bytes], on_error: Callable[
8698 break
8799
88100 if await self .is_interrupted ():
89- raise ClientDisconnect ("Transfer was interrupted by the receiver." )
101+ raise TransferError ("Transfer was interrupted by the receiver." , propagate = False )
90102
91103 await self .store .put_in_queue (chunk )
92104 self .bytes_uploaded += len (chunk )
93105
94106 if self .bytes_uploaded < self .file .size :
95- raise ClientDisconnect ("Received less data than expected." )
107+ raise TransferError ("Received less data than expected." , propagate = True )
96108
97109 self .debug (f"△ End of upload, sending done marker." )
98110 await self .store .put_in_queue (self .DONE_FLAG )
99111
100112 except (ClientDisconnect , WebSocketDisconnect ) as e :
101- self .warning (f"△ Upload error: { str ( e ) } " )
113+ self .error (f"△ Unexpected upload error: { e } " )
102114 await self .store .put_in_queue (self .DEAD_FLAG )
103- await on_error (e )
104115
105116 except asyncio .TimeoutError as e :
106117 self .warning (f"△ Timeout during upload." )
107118 await on_error ("Timeout during upload." )
108119
109- else :
120+ except TransferError as e :
121+ self .warning (f"△ Upload error: { e } " )
122+ if e .propagate :
123+ await self .store .put_in_queue (self .DEAD_FLAG )
124+ else :
125+ await on_error (e )
126+
127+ finally :
110128 await asyncio .sleep (1.0 )
111129
112130 async def supply_download (self , on_error : Callable [[Exception | str ], Awaitable [None ]]) -> AsyncIterator [bytes ]:
@@ -117,10 +135,10 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
117135 chunk = await self .store .get_from_queue ()
118136
119137 if chunk == self .DEAD_FLAG :
120- raise ClientDisconnect ("Sender disconnected." )
138+ raise TransferError ("Sender disconnected." )
121139
122140 if chunk == self .DONE_FLAG and self .bytes_downloaded < self .file .size :
123- raise ClientDisconnect ("Received less data than expected." )
141+ raise TransferError ("Received less data than expected." )
124142
125143 elif chunk == self .DONE_FLAG :
126144 self .debug (f"▼ Done marker received, ending download." )
@@ -129,16 +147,14 @@ async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[
129147 self .bytes_downloaded += len (chunk )
130148 yield chunk
131149
132- except (ClientDisconnect , WebSocketDisconnect ) as e :
133- self .warning (f"▼ Download error: { e } " )
134- await self .set_interrupted ()
135-
136- except asyncio .TimeoutError :
137- self .warning (f"▼ Timeout during download." )
138- await on_error ("Timeout during download." )
150+ except Exception as e :
151+ self .error (f"▼ Unexpected download error!" , exc_info = True )
152+ self .debug ("Debug info:" , stack_info = True )
153+ await on_error (e )
139154
140- else :
141- await asyncio .sleep (1.0 )
155+ except TransferError as e :
156+ self .warning (f"▼ Download error" )
157+ await on_error (e )
142158
143159 async def cleanup (self ):
144160 try :
@@ -148,10 +164,15 @@ async def cleanup(self):
148164 pass
149165
150166 async def finalize_download (self ):
151- self .debug ("▼ Finalizing download..." )
167+ # self.debug("▼ Finalizing download...")
168+ if self .bytes_downloaded < self .file .size and not await self .is_interrupted ():
169+ self .warning ("▼ Client disconnected before download was complete." )
170+ await self .set_interrupted ()
171+
172+ await self .cleanup ()
173+ # self.debug("▼ Finalizing download...")
152174 if self .bytes_downloaded < self .file .size and not await self .is_interrupted ():
153175 self .warning ("▼ Client disconnected before download was complete." )
154176 await self .set_interrupted ()
155177
156- await asyncio .sleep (4.0 )
157178 await self .cleanup ()
0 commit comments