
from __future__ import annotations
import asyncio
-from typing import List, Optional, Tuple
+import google_crc32c
+from google.api_core import exceptions
+from google.api_core.retry_async import AsyncRetry
+
+from typing import List, Optional, Tuple, Any, Dict

from google_crc32c import Checksum

from google.cloud.storage._experimental.asyncio.async_grpc_client import (
    AsyncGrpcClient,
)
+from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import (
+    _BidiStreamRetryManager,
+)
+from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import (
+    _ReadResumptionStrategy,
+    _DownloadState,
+)

from io import BytesIO
from google.cloud import _storage_v2
-from google.cloud.storage.exceptions import DataCorruption
from google.cloud.storage._helpers import generate_random_56_bit_integer


_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100


-class Result:
-    """An instance of this class will be populated and returned for each
-    `read_range` provided to the ``download_ranges`` method.
-
-    """
-
-    def __init__(self, bytes_requested: int):
-        # Set only at instantiation; should not be edited later,
-        # hence there is no setter, only a getter is provided.
-        self._bytes_requested: int = bytes_requested
-        self._bytes_written: int = 0
-
-    @property
-    def bytes_requested(self) -> int:
-        return self._bytes_requested
-
-    @property
-    def bytes_written(self) -> int:
-        return self._bytes_written
-
-    @bytes_written.setter
-    def bytes_written(self, value: int):
-        self._bytes_written = value
-
-    def __repr__(self):
-        return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}"
-
-
class AsyncMultiRangeDownloader:
    """Provides an interface for downloading multiple ranges of a GCS ``Object``
    concurrently.
@@ -103,6 +85,7 @@ async def create_mrd(
        object_name: str,
        generation_number: Optional[int] = None,
        read_handle: Optional[bytes] = None,
+        retry_policy: Optional[AsyncRetry] = None,
    ) -> AsyncMultiRangeDownloader:
        """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC
        object for reading.
@@ -124,11 +107,14 @@ async def create_mrd(
        :param read_handle: (Optional) An existing handle for reading the object.
            If provided, opening the bidi-gRPC connection will be faster.

+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the ``open`` operation.
+
        :rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader`
        :returns: An initialized AsyncMultiRangeDownloader instance for reading.
        """
        mrd = cls(client, bucket_name, object_name, generation_number, read_handle)
-        await mrd.open()
+        await mrd.open(retry_policy=retry_policy)
        return mrd

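For reference, a minimal usage sketch of the retry-aware `create_mrd` path. The bucket and object names and the bare `AsyncGrpcClient()` construction are illustrative assumptions, not part of this change:

```
from google.api_core import exceptions
from google.api_core.retry_async import AsyncRetry

from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import (
    AsyncMultiRangeDownloader,
)


async def make_downloader() -> AsyncMultiRangeDownloader:
    # Illustrative client construction; configure AsyncGrpcClient as needed
    # for your environment.
    client = AsyncGrpcClient()

    # Retry the same transient errors the module treats as retryable by default.
    policy = AsyncRetry(
        predicate=lambda e: isinstance(
            e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded)
        ),
        initial=1.0,
        maximum=10.0,
        timeout=60.0,
    )

    # The retry policy is applied to the underlying open() call.
    return await AsyncMultiRangeDownloader.create_mrd(
        client,
        "my-bucket",  # assumed bucket name
        "my-object",  # assumed object name
        retry_policy=policy,
    )
```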
    def __init__(
@@ -174,7 +160,7 @@ def __init__(
        self._download_ranges_id_to_pending_read_ids = {}
        self.persisted_size: Optional[int] = None  # updated after opening the stream

-    async def open(self) -> None:
+    async def open(self, retry_policy: Optional[AsyncRetry] = None) -> None:
        """Opens the bidi-gRPC connection to read from the object.

        This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to
@@ -185,27 +171,43 @@ async def open(self) -> None:
        """
        if self._is_stream_open:
            raise ValueError("Underlying bidi-gRPC stream is already open")
-        self.read_obj_str = _AsyncReadObjectStream(
-            client=self.client,
-            bucket_name=self.bucket_name,
-            object_name=self.object_name,
-            generation_number=self.generation_number,
-            read_handle=self.read_handle,
-        )
-        await self.read_obj_str.open()
-        self._is_stream_open = True
-        if self.generation_number is None:
-            self.generation_number = self.read_obj_str.generation_number
-        self.read_handle = self.read_obj_str.read_handle
-        if self.read_obj_str.persisted_size is not None:
-            self.persisted_size = self.read_obj_str.persisted_size
-        return
+
+        if retry_policy is None:
+            # Default policy: retry generic transient errors
+            retry_policy = AsyncRetry(
+                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+            )
+
+        async def _do_open():
+            self.read_obj_str = _AsyncReadObjectStream(
+                client=self.client,
+                bucket_name=self.bucket_name,
+                object_name=self.object_name,
+                generation_number=self.generation_number,
+                read_handle=self.read_handle,
+            )
+            await self.read_obj_str.open()
+
+            if self.read_obj_str.generation_number:
+                self.generation_number = self.read_obj_str.generation_number
+            if self.read_obj_str.read_handle:
+                self.read_handle = self.read_obj_str.read_handle
+            if self.read_obj_str.persisted_size is not None:
+                self.persisted_size = self.read_obj_str.persisted_size
+
+            self._is_stream_open = True
+
+        # Execute open with retry policy
+        await retry_policy(_do_open)()

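As a sketch of how a caller might widen the default predicate used above, an `AsyncRetry` with extra retryable error types and explicit backoff can be passed straight to `open()`; the unopened downloader instance `mrd` is assumed to already exist:

```
from google.api_core import exceptions
from google.api_core.retry_async import AsyncRetry


async def open_with_custom_retry(mrd) -> None:
    # Also treat ABORTED and INTERNAL as retryable, with exponential backoff
    # and a 30-second overall timeout.
    custom_policy = AsyncRetry(
        predicate=lambda e: isinstance(
            e,
            (
                exceptions.ServiceUnavailable,
                exceptions.DeadlineExceeded,
                exceptions.Aborted,
                exceptions.InternalServerError,
            ),
        ),
        initial=0.5,
        multiplier=2.0,
        maximum=8.0,
        timeout=30.0,
    )

    await mrd.open(retry_policy=custom_policy)
```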
    async def download_ranges(
-        self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None
+        self,
+        read_ranges: List[Tuple[int, int, BytesIO]],
+        lock: Optional[asyncio.Lock] = None,
+        retry_policy: Optional[AsyncRetry] = None,
    ) -> None:
        """Downloads multiple byte ranges from the object into the buffers
-        provided by user.
+        provided by the user, with automatic retries.

        :type read_ranges: List[Tuple[int, int, "BytesIO"]]
        :param read_ranges: A list of tuples, where each tuple represents a
@@ -240,6 +242,8 @@ async def download_ranges(

        ```

+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the operation.

        :raises ValueError: if the underlying bidi-gRPC stream is not open.
        :raises ValueError: if the length of read_ranges is more than 1000.
@@ -258,80 +262,101 @@ async def download_ranges(
        if lock is None:
            lock = asyncio.Lock()

-        _func_id = generate_random_56_bit_integer()
-        read_ids_in_current_func = set()
-        for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
-            read_ranges_segment = read_ranges[
-                i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
-            ]
+        if retry_policy is None:
+            retry_policy = AsyncRetry(
+                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+            )
+
+        # Initialize global state for the retry strategy
+        download_states = {}
+        for read_range in read_ranges:
+            read_id = generate_random_56_bit_integer()
+            download_states[read_id] = _DownloadState(
+                initial_offset=read_range[0],
+                initial_length=read_range[1],
+                user_buffer=read_range[2],
+            )
+
+        initial_state = {
+            "download_states": download_states,
+            "read_handle": self.read_handle,
+            "routing_token": None,
+        }
+
+        # Track attempts to manage stream reuse
+        is_first_attempt = True
+
+        def stream_opener(requests: List[_storage_v2.ReadRange], state: Dict[str, Any]):
+
+            async def generator():
+                nonlocal is_first_attempt
+
+                async with lock:
+                    current_handle = state.get("read_handle")
+                    current_token = state.get("routing_token")
+
+                    # We reopen if it's a redirect (token exists) OR if this is a retry
+                    # (not first attempt). This prevents trying to send data on a dead
+                    # stream from a previous failed attempt.
+                    should_reopen = (not is_first_attempt) or (current_token is not None)

-            read_ranges_for_bidi_req = []
-            for j, read_range in enumerate(read_ranges_segment):
-                read_id = generate_random_56_bit_integer()
-                read_ids_in_current_func.add(read_id)
-                self._read_id_to_download_ranges_id[read_id] = _func_id
-                self._read_id_to_writable_buffer_dict[read_id] = read_range[2]
-                bytes_requested = read_range[1]
-                read_ranges_for_bidi_req.append(
-                    _storage_v2.ReadRange(
-                        read_offset=read_range[0],
-                        read_length=bytes_requested,
-                        read_id=read_id,
-                    )
-                )
-            async with lock:
-                await self.read_obj_str.send(
-                    _storage_v2.BidiReadObjectRequest(
-                        read_ranges=read_ranges_for_bidi_req
-                    )
-                )
-        self._download_ranges_id_to_pending_read_ids[
-            _func_id
-        ] = read_ids_in_current_func
-
-        while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0:
-            async with lock:
-                response = await self.read_obj_str.recv()
-
-            if response is None:
-                raise Exception("None response received, something went wrong.")
-
-            for object_data_range in response.object_data_ranges:
-                if object_data_range.read_range is None:
-                    raise Exception("Invalid response, read_range is None")
-
-                checksummed_data = object_data_range.checksummed_data
-                data = checksummed_data.content
-                server_checksum = checksummed_data.crc32c
-
-                client_crc32c = Checksum(data).digest()
-                client_checksum = int.from_bytes(client_crc32c, "big")
-
-                if server_checksum != client_checksum:
-                    raise DataCorruption(
-                        response,
-                        f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. "
-                        f"Server sent {server_checksum}, client calculated {client_checksum}.",
-                    )
-
-                read_id = object_data_range.read_range.read_id
-                buffer = self._read_id_to_writable_buffer_dict[read_id]
-                buffer.write(data)
-
-                if object_data_range.range_end:
-                    tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id]
-                    self._download_ranges_id_to_pending_read_ids[
-                        tmp_dn_ranges_id
-                    ].remove(read_id)
-                    del self._read_id_to_download_ranges_id[read_id]
+                    if should_reopen:
+                        # Close existing stream if any
+                        if self.read_obj_str:
+                            await self.read_obj_str.close()
+
+                        # Re-initialize stream
+                        self.read_obj_str = _AsyncReadObjectStream(
+                            client=self.client,
+                            bucket_name=self.bucket_name,
+                            object_name=self.object_name,
+                            generation_number=self.generation_number,
+                            read_handle=current_handle,
+                        )
+
+                        # Inject routing_token into metadata if present
+                        metadata = []
+                        if current_token:
+                            metadata.append(("x-goog-request-params", f"routing_token={current_token}"))
+
+                        await self.read_obj_str.open(metadata=metadata if metadata else None)
+                        self._is_stream_open = True
+
+                    # Mark first attempt as done; next time this runs it will be a retry
+                    is_first_attempt = False
+
+                    # Send requests in batches
+                    for i in range(0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
+                        batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
+                        await self.read_obj_str.send(
+                            _storage_v2.BidiReadObjectRequest(read_ranges=batch)
+                        )
+
+                    while True:
+                        response = await self.read_obj_str.recv()
+                        if response is None:
+                            break
+                        yield response
+
+            return generator()
+
+        strategy = _ReadResumptionStrategy()
+        retry_manager = _BidiStreamRetryManager(strategy, stream_opener)
+
+        await retry_manager.execute(initial_state, retry_policy)
+
+        if initial_state.get("read_handle"):
+            self.read_handle = initial_state["read_handle"]

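For orientation, a usage sketch of the retried download path; the offsets, lengths, and the already-opened downloader `mrd` are assumptions for illustration:

```
import asyncio
from io import BytesIO


async def fetch_two_ranges(mrd):
    # Each tuple is (read_offset, read_length, writable_buffer).
    first = BytesIO()
    second = BytesIO()
    read_ranges = [
        (0, 1024, first),  # assumed: first 1 KiB of the object
        (4096, 2048, second),  # assumed: 2 KiB starting at offset 4096
    ]

    # A lock may be passed so concurrent callers serialize access to the
    # shared stream; if omitted, download_ranges creates its own.
    lock = asyncio.Lock()
    await mrd.download_ranges(read_ranges, lock=lock)

    return first.getvalue(), second.getvalue()
```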
    async def close(self):
        """
        Closes the underlying bidi-gRPC connection.
        """
        if not self._is_stream_open:
            raise ValueError("Underlying bidi-gRPC stream is not open")
-        await self.read_obj_str.close()
+
+        if self.read_obj_str:
+            await self.read_obj_str.close()
        self.read_obj_str = None
        self._is_stream_open = False

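Because `close()` raises `ValueError` when the stream is not open, callers will typically pair downloads with a `try/finally` once `open()` has succeeded; a small sketch:

```
async def download_and_close(mrd, read_ranges) -> None:
    # Assumes mrd was created via create_mrd(), i.e. the stream is open.
    try:
        await mrd.download_ranges(read_ranges)
    finally:
        await mrd.close()
```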