Skip to content

Commit a57ea0e

Browse files
authored
feat(experimental): add write resumption strategy (#1663)
Adding writes resumption strategy which will be used for error handling of bidi writes operation.
1 parent 0c35d3f commit a57ea0e

File tree

3 files changed

+446
-4
lines changed

3 files changed

+446
-4
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, IO, Iterable, Optional, Union
16+
17+
import google_crc32c
18+
from google.cloud._storage_v2.types import storage as storage_type
19+
from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError
20+
from google.cloud.storage._experimental.asyncio.retry.base_strategy import (
21+
_BaseResumptionStrategy,
22+
)
23+
24+
25+
class _WriteState:
26+
"""A helper class to track the state of a single upload operation.
27+
28+
:type spec: :class:`google.cloud.storage_v2.types.AppendObjectSpec`
29+
:param spec: The specification for the object to write.
30+
31+
:type chunk_size: int
32+
:param chunk_size: The size of chunks to write to the server.
33+
34+
:type user_buffer: IO[bytes]
35+
:param user_buffer: The data source.
36+
"""
37+
38+
def __init__(
39+
self,
40+
spec: Union[storage_type.AppendObjectSpec, storage_type.WriteObjectSpec],
41+
chunk_size: int,
42+
user_buffer: IO[bytes],
43+
):
44+
self.spec = spec
45+
self.chunk_size = chunk_size
46+
self.user_buffer = user_buffer
47+
self.persisted_size: int = 0
48+
self.bytes_sent: int = 0
49+
self.write_handle: Union[bytes, storage_type.BidiWriteHandle, None] = None
50+
self.routing_token: Optional[str] = None
51+
self.is_finalized: bool = False
52+
53+
54+
class _WriteResumptionStrategy(_BaseResumptionStrategy):
55+
"""The concrete resumption strategy for bidi writes."""
56+
57+
def generate_requests(
58+
self, state: Dict[str, Any]
59+
) -> Iterable[storage_type.BidiWriteObjectRequest]:
60+
"""Generates BidiWriteObjectRequests to resume or continue the upload.
61+
62+
For Appendable Objects, every stream opening should send an
63+
AppendObjectSpec. If resuming, the `write_handle` is added to that spec.
64+
65+
This method is not applicable for `open` methods.
66+
"""
67+
write_state: _WriteState = state["write_state"]
68+
69+
initial_request = storage_type.BidiWriteObjectRequest()
70+
71+
# Determine if we need to send WriteObjectSpec or AppendObjectSpec
72+
if isinstance(write_state.spec, storage_type.WriteObjectSpec):
73+
initial_request.write_object_spec = write_state.spec
74+
else:
75+
if write_state.write_handle:
76+
write_state.spec.write_handle = write_state.write_handle
77+
78+
if write_state.routing_token:
79+
write_state.spec.routing_token = write_state.routing_token
80+
initial_request.append_object_spec = write_state.spec
81+
82+
yield initial_request
83+
84+
# The buffer should already be seeked to the correct position (persisted_size)
85+
# by the `recover_state_on_failure` method before this is called.
86+
while not write_state.is_finalized:
87+
chunk = write_state.user_buffer.read(write_state.chunk_size)
88+
89+
# End of File detection
90+
if not chunk:
91+
return
92+
93+
checksummed_data = storage_type.ChecksummedData(content=chunk)
94+
checksum = google_crc32c.Checksum(chunk)
95+
checksummed_data.crc32c = int.from_bytes(checksum.digest(), "big")
96+
97+
request = storage_type.BidiWriteObjectRequest(
98+
write_offset=write_state.bytes_sent,
99+
checksummed_data=checksummed_data,
100+
)
101+
write_state.bytes_sent += len(chunk)
102+
103+
yield request
104+
105+
def update_state_from_response(
106+
self, response: storage_type.BidiWriteObjectResponse, state: Dict[str, Any]
107+
) -> None:
108+
"""Processes a server response and updates the write state."""
109+
write_state: _WriteState = state["write_state"]
110+
111+
if response.persisted_size:
112+
write_state.persisted_size = response.persisted_size
113+
114+
if response.write_handle:
115+
write_state.write_handle = response.write_handle
116+
117+
if response.resource:
118+
write_state.persisted_size = response.resource.size
119+
if response.resource.finalize_time:
120+
write_state.is_finalized = True
121+
122+
async def recover_state_on_failure(
123+
self, error: Exception, state: Dict[str, Any]
124+
) -> None:
125+
"""
126+
Handles errors, specifically BidiWriteObjectRedirectedError, and rewinds state.
127+
128+
This method rewinds the user buffer and internal byte tracking to the
129+
last confirmed 'persisted_size' from the server.
130+
"""
131+
write_state: _WriteState = state["write_state"]
132+
cause = getattr(error, "cause", error)
133+
134+
# Extract routing token and potentially a new write handle for redirection.
135+
if isinstance(cause, BidiWriteObjectRedirectedError):
136+
if cause.routing_token:
137+
write_state.routing_token = cause.routing_token
138+
139+
redirect_handle = getattr(cause, "write_handle", None)
140+
if redirect_handle:
141+
write_state.write_handle = redirect_handle
142+
143+
# We must assume any data sent beyond 'persisted_size' was lost.
144+
# Reset the user buffer to the last known good byte confirmed by the server.
145+
write_state.user_buffer.seek(write_state.persisted_size)
146+
write_state.bytes_sent = write_state.persisted_size

tests/unit/asyncio/retry/test_reads_resumption_strategy.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
from google.cloud._storage_v2.types.storage import BidiReadObjectRedirectedError
2828

2929
_READ_ID = 1
30-
LOGGER_NAME = "google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy"
30+
LOGGER_NAME = (
31+
"google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy"
32+
)
3133

3234

3335
class TestDownloadState(unittest.TestCase):
@@ -309,7 +311,9 @@ def test_update_state_missing_read_range_logs_warning(self):
309311
with self.assertLogs(LOGGER_NAME, level="WARNING") as cm:
310312
self.strategy.update_state_from_response(response, self.state)
311313

312-
self.assertTrue(any("missing read_range field" in output for output in cm.output))
314+
self.assertTrue(
315+
any("missing read_range field" in output for output in cm.output)
316+
)
313317

314318
def test_update_state_unknown_id_logs_warning(self):
315319
"""Verify we log a warning and continue when read_id is unknown."""
@@ -320,8 +324,12 @@ def test_update_state_unknown_id_logs_warning(self):
320324
with self.assertLogs(LOGGER_NAME, level="WARNING") as cm:
321325
self.strategy.update_state_from_response(response, self.state)
322326

323-
self.assertTrue(any(f"unknown or stale read_id {unknown_id}" in output for output in cm.output))
324-
327+
self.assertTrue(
328+
any(
329+
f"unknown or stale read_id {unknown_id}" in output
330+
for output in cm.output
331+
)
332+
)
325333

326334
# --- Recovery Tests ---
327335

0 commit comments

Comments
 (0)