Skip to content

Commit 722a5ed

Browse files
committed
integrate retry logic with the MRD
1 parent 1b680e8 commit 722a5ed

File tree

4 files changed

+432
-281
lines changed

4 files changed

+432
-281
lines changed

google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py

Lines changed: 140 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414

1515
from __future__ import annotations
1616
import asyncio
17-
from typing import List, Optional, Tuple
17+
import google_crc32c
18+
from google.api_core import exceptions
19+
from google.api_core.retry_async import AsyncRetry
20+
21+
from typing import List, Optional, Tuple, Any, Dict
1822

1923
from google_crc32c import Checksum
2024

@@ -25,44 +29,22 @@
2529
from google.cloud.storage._experimental.asyncio.async_grpc_client import (
2630
AsyncGrpcClient,
2731
)
32+
from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import (
33+
_BidiStreamRetryManager,
34+
)
35+
from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import (
36+
_ReadResumptionStrategy,
37+
_DownloadState,
38+
)
2839

2940
from io import BytesIO
3041
from google.cloud import _storage_v2
31-
from google.cloud.storage.exceptions import DataCorruption
3242
from google.cloud.storage._helpers import generate_random_56_bit_integer
3343

3444

3545
_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100
3646

3747

38-
class Result:
    """Per-range download outcome, populated and returned for each
    ``read_range`` passed to the ``download_ranges`` method.

    :type bytes_requested: int
    :param bytes_requested: Number of bytes the caller asked for in this
        range. Fixed at construction time.
    """

    def __init__(self, bytes_requested: int):
        # Set only at instantiation and never edited afterwards, hence
        # there is no setter for it — only a getter is provided.
        self._bytes_requested: int = bytes_requested
        # Updated as data arrives; starts at zero before any bytes land.
        self._bytes_written: int = 0

    @property
    def bytes_requested(self) -> int:
        """Number of bytes originally requested for this range (read-only)."""
        return self._bytes_requested

    @property
    def bytes_written(self) -> int:
        """Number of bytes actually written to the caller's buffer so far."""
        return self._bytes_written

    @bytes_written.setter
    def bytes_written(self, value: int):
        self._bytes_written = value

    def __repr__(self):
        return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}"
64-
65-
6648
class AsyncMultiRangeDownloader:
6749
"""Provides an interface for downloading multiple ranges of a GCS ``Object``
6850
concurrently.
@@ -103,6 +85,7 @@ async def create_mrd(
10385
object_name: str,
10486
generation_number: Optional[int] = None,
10587
read_handle: Optional[bytes] = None,
88+
retry_policy: Optional[AsyncRetry] = None,
10689
) -> AsyncMultiRangeDownloader:
10790
"""Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC
10891
object for reading.
@@ -124,11 +107,14 @@ async def create_mrd(
124107
:param read_handle: (Optional) An existing handle for reading the object.
125108
If provided, opening the bidi-gRPC connection will be faster.
126109
110+
:type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
111+
:param retry_policy: (Optional) The retry policy to use for the ``open`` operation.
112+
127113
:rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader`
128114
:returns: An initialized AsyncMultiRangeDownloader instance for reading.
129115
"""
130116
mrd = cls(client, bucket_name, object_name, generation_number, read_handle)
131-
await mrd.open()
117+
await mrd.open(retry_policy=retry_policy)
132118
return mrd
133119

134120
def __init__(
@@ -174,7 +160,7 @@ def __init__(
174160
self._download_ranges_id_to_pending_read_ids = {}
175161
self.persisted_size: Optional[int] = None # updated after opening the stream
176162

177-
async def open(self) -> None:
163+
async def open(self, retry_policy: Optional[AsyncRetry] = None) -> None:
178164
"""Opens the bidi-gRPC connection to read from the object.
179165
180166
This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to
@@ -185,27 +171,43 @@ async def open(self) -> None:
185171
"""
186172
if self._is_stream_open:
187173
raise ValueError("Underlying bidi-gRPC stream is already open")
188-
self.read_obj_str = _AsyncReadObjectStream(
189-
client=self.client,
190-
bucket_name=self.bucket_name,
191-
object_name=self.object_name,
192-
generation_number=self.generation_number,
193-
read_handle=self.read_handle,
194-
)
195-
await self.read_obj_str.open()
196-
self._is_stream_open = True
197-
if self.generation_number is None:
198-
self.generation_number = self.read_obj_str.generation_number
199-
self.read_handle = self.read_obj_str.read_handle
200-
if self.read_obj_str.persisted_size is not None:
201-
self.persisted_size = self.read_obj_str.persisted_size
202-
return
174+
175+
if retry_policy is None:
176+
# Default policy: retry generic transient errors
177+
retry_policy = AsyncRetry(
178+
predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
179+
)
180+
181+
async def _do_open():
182+
self.read_obj_str = _AsyncReadObjectStream(
183+
client=self.client,
184+
bucket_name=self.bucket_name,
185+
object_name=self.object_name,
186+
generation_number=self.generation_number,
187+
read_handle=self.read_handle,
188+
)
189+
await self.read_obj_str.open()
190+
191+
if self.read_obj_str.generation_number:
192+
self.generation_number = self.read_obj_str.generation_number
193+
if self.read_obj_str.read_handle:
194+
self.read_handle = self.read_obj_str.read_handle
195+
if self.read_obj_str.persisted_size is not None:
196+
self.persisted_size = self.read_obj_str.persisted_size
197+
198+
self._is_stream_open = True
199+
200+
# Execute open with retry policy
201+
await retry_policy(_do_open)()
203202

204203
async def download_ranges(
205-
self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None
204+
self,
205+
read_ranges: List[Tuple[int, int, BytesIO]],
206+
lock: asyncio.Lock = None,
207+
retry_policy: AsyncRetry = None
206208
) -> None:
207209
"""Downloads multiple byte ranges from the object into the buffers
208-
provided by user.
210+
provided by user with automatic retries.
209211
210212
:type read_ranges: List[Tuple[int, int, "BytesIO"]]
211213
:param read_ranges: A list of tuples, where each tuple represents a
@@ -240,6 +242,8 @@ async def download_ranges(
240242
241243
```
242244
245+
:type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
246+
:param retry_policy: (Optional) The retry policy to use for the operation.
243247
244248
:raises ValueError: if the underlying bidi-GRPC stream is not open.
245249
:raises ValueError: if the length of read_ranges is more than 1000.
@@ -258,80 +262,101 @@ async def download_ranges(
258262
if lock is None:
259263
lock = asyncio.Lock()
260264

261-
_func_id = generate_random_56_bit_integer()
262-
read_ids_in_current_func = set()
263-
for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
264-
read_ranges_segment = read_ranges[
265-
i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
266-
]
265+
if retry_policy is None:
266+
retry_policy = AsyncRetry(
267+
predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
268+
)
269+
270+
# Initialize Global State for Retry Strategy
271+
download_states = {}
272+
for read_range in read_ranges:
273+
read_id = generate_random_56_bit_integer()
274+
download_states[read_id] = _DownloadState(
275+
initial_offset=read_range[0],
276+
initial_length=read_range[1],
277+
user_buffer=read_range[2]
278+
)
279+
280+
initial_state = {
281+
"download_states": download_states,
282+
"read_handle": self.read_handle,
283+
"routing_token": None
284+
}
285+
286+
# Track attempts to manage stream reuse
287+
is_first_attempt = True
288+
289+
def stream_opener(requests: List[_storage_v2.ReadRange], state: Dict[str, Any]):
290+
291+
async def generator():
292+
nonlocal is_first_attempt
293+
294+
async with lock:
295+
current_handle = state.get("read_handle")
296+
current_token = state.get("routing_token")
297+
298+
# We reopen if it's a redirect (token exists) OR if this is a retry
299+
# (not first attempt). This prevents trying to send data on a dead
300+
# stream from a previous failed attempt.
301+
should_reopen = (not is_first_attempt) or (current_token is not None)
267302

268-
read_ranges_for_bidi_req = []
269-
for j, read_range in enumerate(read_ranges_segment):
270-
read_id = generate_random_56_bit_integer()
271-
read_ids_in_current_func.add(read_id)
272-
self._read_id_to_download_ranges_id[read_id] = _func_id
273-
self._read_id_to_writable_buffer_dict[read_id] = read_range[2]
274-
bytes_requested = read_range[1]
275-
read_ranges_for_bidi_req.append(
276-
_storage_v2.ReadRange(
277-
read_offset=read_range[0],
278-
read_length=bytes_requested,
279-
read_id=read_id,
280-
)
281-
)
282-
async with lock:
283-
await self.read_obj_str.send(
284-
_storage_v2.BidiReadObjectRequest(
285-
read_ranges=read_ranges_for_bidi_req
286-
)
287-
)
288-
self._download_ranges_id_to_pending_read_ids[
289-
_func_id
290-
] = read_ids_in_current_func
291-
292-
while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0:
293-
async with lock:
294-
response = await self.read_obj_str.recv()
295-
296-
if response is None:
297-
raise Exception("None response received, something went wrong.")
298-
299-
for object_data_range in response.object_data_ranges:
300-
if object_data_range.read_range is None:
301-
raise Exception("Invalid response, read_range is None")
302-
303-
checksummed_data = object_data_range.checksummed_data
304-
data = checksummed_data.content
305-
server_checksum = checksummed_data.crc32c
306-
307-
client_crc32c = Checksum(data).digest()
308-
client_checksum = int.from_bytes(client_crc32c, "big")
309-
310-
if server_checksum != client_checksum:
311-
raise DataCorruption(
312-
response,
313-
f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. "
314-
f"Server sent {server_checksum}, client calculated {client_checksum}.",
315-
)
316-
317-
read_id = object_data_range.read_range.read_id
318-
buffer = self._read_id_to_writable_buffer_dict[read_id]
319-
buffer.write(data)
320-
321-
if object_data_range.range_end:
322-
tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id]
323-
self._download_ranges_id_to_pending_read_ids[
324-
tmp_dn_ranges_id
325-
].remove(read_id)
326-
del self._read_id_to_download_ranges_id[read_id]
303+
if should_reopen:
304+
# Close existing stream if any
305+
if self.read_obj_str:
306+
await self.read_obj_str.close()
307+
308+
# Re-initialize stream
309+
self.read_obj_str = _AsyncReadObjectStream(
310+
client=self.client,
311+
bucket_name=self.bucket_name,
312+
object_name=self.object_name,
313+
generation_number=self.generation_number,
314+
read_handle=current_handle,
315+
)
316+
317+
# Inject routing_token into metadata if present
318+
metadata = []
319+
if current_token:
320+
metadata.append(("x-goog-request-params", f"routing_token={current_token}"))
321+
322+
await self.read_obj_str.open(metadata=metadata if metadata else None)
323+
self._is_stream_open = True
324+
325+
# Mark first attempt as done; next time this runs it will be a retry
326+
is_first_attempt = False
327+
328+
# Send Requests
329+
for i in range(0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
330+
batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
331+
await self.read_obj_str.send(
332+
_storage_v2.BidiReadObjectRequest(read_ranges=batch)
333+
)
334+
335+
while True:
336+
response = await self.read_obj_str.recv()
337+
if response is None:
338+
break
339+
yield response
340+
341+
return generator()
342+
343+
strategy = _ReadResumptionStrategy()
344+
retry_manager = _BidiStreamRetryManager(strategy, stream_opener)
345+
346+
await retry_manager.execute(initial_state, retry_policy)
347+
348+
if initial_state.get("read_handle"):
349+
self.read_handle = initial_state["read_handle"]
327350

328351
async def close(self):
329352
"""
330353
Closes the underlying bidi-gRPC connection.
331354
"""
332355
if not self._is_stream_open:
333356
raise ValueError("Underlying bidi-gRPC stream is not open")
334-
await self.read_obj_str.close()
357+
358+
if self.read_obj_str:
359+
await self.read_obj_str.close()
335360
self.read_obj_str = None
336361
self._is_stream_open = False
337362

0 commit comments

Comments
 (0)