Commit ccf667a: "more changes"
1 parent d1cc1ef

5 files changed: +463 -145 lines

google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py

Lines changed: 136 additions & 122 deletions
```diff
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
+from __future__ import annotations
+
 NOTE:
 This is _experimental module for upcoming support for Rapid Storage.
 (https://cloud.google.com/blog/products/storage-data-transfer/high-performance-storage-innovations-for-ai-hpc#:~:text=your%20AI%20workloads%3A-,Rapid%20Storage,-%3A%20A%20new)
```
```diff
@@ -23,13 +25,15 @@
 """
 from io import BufferedReader, BytesIO
 import asyncio
+import io
 from typing import List, Optional, Tuple, Union
 
 from google_crc32c import Checksum
 from google.api_core import exceptions
 from google.api_core.retry_async import AsyncRetry
 from google.rpc import status_pb2
 from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError
+from google.cloud._storage_v2.types.storage import BidiWriteObjectRequest
 
 
 from ._utils import raise_if_no_fast_crc32c
```
```diff
@@ -58,6 +62,9 @@
 
 def _is_write_retryable(exc):
     """Predicate to determine if a write operation should be retried."""
+
+    print("In _is_write_retryable method, exception:", exc)
+
     if isinstance(
         exc,
         (
```
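Reviewer note: `_is_write_retryable` is the predicate this module hands to `google.api_core.retry_async.AsyncRetry` (later hunks construct `AsyncRetry(predicate=_is_write_retryable)` and invoke policies as `await retry_policy(_do_open)()`). A minimal, self-contained sketch of that contract follows; the predicate body and the `upload_once` coroutine are illustrative stand-ins, not code from this module:

```python
import asyncio

from google.api_core import exceptions
from google.api_core.retry_async import AsyncRetry


def is_write_retryable(exc):
    # Illustrative stand-in for _is_write_retryable: retry only transient errors.
    return isinstance(exc, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))


attempts = 0


async def upload_once():
    # Hypothetical single attempt; fails twice, then succeeds.
    global attempts
    attempts += 1
    if attempts < 3:
        raise exceptions.ServiceUnavailable("transient blip")
    return "ok"


async def main():
    retry_policy = AsyncRetry(predicate=is_write_retryable)
    # Same invocation shape as `await retry_policy(_do_open)()` in this diff.
    print(await retry_policy(upload_once)())  # -> "ok" on the third attempt


asyncio.run(main())
```

The predicate receives each raised exception and returns True to keep retrying; AsyncRetry handles the backoff between attempts.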
```diff
@@ -192,6 +199,17 @@ def __init__(
         self.bytes_appended_since_last_flush = 0
         self._lock = asyncio.Lock()
         self._routing_token: Optional[str] = None
+        self.object_resource: Optional[_storage_v2.Object] = None
+
+    def _stream_opener(self, write_handle=None):
+        """Helper to create a new _AsyncWriteObjectStream."""
+        return _AsyncWriteObjectStream(
+            client=self.client,
+            bucket_name=self.bucket_name,
+            object_name=self.object_name,
+            generation_number=self.generation,
+            write_handle=write_handle if write_handle else self.write_handle,
+        )
 
     async def state_lookup(self) -> int:
         """Returns the persisted_size
```
```diff
@@ -205,14 +223,15 @@ async def state_lookup(self) -> int:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before state_lookup().")
 
-        await self.write_obj_stream.send(
-            _storage_v2.BidiWriteObjectRequest(
-                state_lookup=True,
+        async with self._lock:
+            await self.write_obj_stream.send(
+                _storage_v2.BidiWriteObjectRequest(
+                    state_lookup=True,
+                )
             )
-        )
-        response = await self.write_obj_stream.recv()
-        self.persisted_size = response.persisted_size
-        return self.persisted_size
+            response = await self.write_obj_stream.recv()
+            self.persisted_size = response.persisted_size
+            return self.persisted_size
 
     def _on_open_error(self, exc):
         """Extracts routing token and write handle on redirect error during open."""
```
```diff
@@ -288,6 +307,7 @@ def combined_on_error(exc):
         ).with_on_error(combined_on_error)
 
         async def _do_open():
+            print("In _do_open method")
             current_metadata = list(metadata) if metadata else []
 
             # Cleanup stream from previous failed attempt, if any.
```
```diff
@@ -314,6 +334,7 @@ async def _do_open():
                 )
                 self._routing_token = None
 
+            print("Current metadata in _do_open:", current_metadata)
             await self.write_obj_stream.open(
                 metadata=current_metadata if metadata else None
             )
```
```diff
@@ -327,108 +348,9 @@ async def _do_open():
 
             self._is_stream_open = True
 
+        print("In open method, before retry_policy call")
         await retry_policy(_do_open)()
 
-    async def _upload_with_retry(
-        self,
-        data: bytes,
-        retry_policy: Optional[AsyncRetry] = None,
-        metadata: Optional[List[Tuple[str, str]]] = None,
-    ) -> None:
-        if not self._is_stream_open:
-            raise ValueError("Underlying bidi-gRPC stream is not open")
-
-        if retry_policy is None:
-            retry_policy = AsyncRetry(predicate=_is_write_retryable)
-
-        # Initialize Global State for Retry Strategy
-        spec = _storage_v2.AppendObjectSpec(
-            bucket=self.bucket_name,
-            object=self.object_name,
-            generation=self.generation,
-        )
-        buffer = BytesIO(data)
-        write_state = _WriteState(
-            spec=spec,
-            chunk_size=_MAX_CHUNK_SIZE_BYTES,
-            user_buffer=buffer,
-        )
-        write_state.write_handle = self.write_handle
-
-        initial_state = {
-            "write_state": write_state,
-            "first_request": True,
-        }
-
-        # Track attempts to manage stream reuse
-        attempt_count = 0
-
-        def stream_opener(
-            requests,
-            state,
-            metadata: Optional[List[Tuple[str, str]]] = None,
-        ):
-            async def generator():
-                nonlocal attempt_count
-                attempt_count += 1
-
-                async with self._lock:
-                    current_handle = state["write_state"].write_handle
-                    current_token = state["write_state"].routing_token
-
-                    should_reopen = (attempt_count > 1) or (current_token is not None)
-
-                    if should_reopen:
-                        if self.write_obj_stream and self.write_obj_stream._is_stream_open:
-                            await self.write_obj_stream.close()
-
-                        self.write_obj_stream = _AsyncWriteObjectStream(
-                            client=self.client,
-                            bucket_name=self.bucket_name,
-                            object_name=self.object_name,
-                            generation_number=self.generation,
-                            write_handle=current_handle,
-                        )
-
-                        current_metadata = list(metadata) if metadata else []
-                        if current_token:
-                            current_metadata.append(
-                                (
-                                    "x-goog-request-params",
-                                    f"routing_token={current_token}",
-                                )
-                            )
-
-                        await self.write_obj_stream.open(
-                            metadata=current_metadata if current_metadata else None
-                        )
-                        self._is_stream_open = True
-
-                # Let the strategy generate the request sequence
-                async for request in requests:
-                    await self.write_obj_stream.send(request)
-
-                # Signal that we are done sending requests.
-                await self.write_obj_stream.requests.put(None)
-
-                # Process responses
-                async for response in self.write_obj_stream:
-                    yield response
-
-            return generator()
-
-        strategy = _WriteResumptionStrategy()
-        retry_manager = _BidiStreamRetryManager(
-            strategy, lambda r, s: stream_opener(r, s, metadata=metadata)
-        )
-
-        await retry_manager.execute(initial_state, retry_policy)
-
-        # Update the writer's state from the strategy's final state
-        final_write_state = initial_state["write_state"]
-        self.persisted_size = final_write_state.persisted_size
-        self.write_handle = final_write_state.write_handle
-        self.offset = self.persisted_size
 
     async def append(
         self,
```
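Reviewer note: the deleted `_upload_with_retry` handled server redirects inside its `stream_opener` by reopening the stream and forwarding the routing token through the `x-goog-request-params` metadata header; the same construction reappears inside `append()` in the next hunk. A small sketch of just that metadata-building step (`build_open_metadata` is a hypothetical name, not part of the module):

```python
from typing import List, Optional, Tuple


def build_open_metadata(
    metadata: Optional[List[Tuple[str, str]]],
    routing_token: Optional[str],
) -> Optional[List[Tuple[str, str]]]:
    # Copy caller-supplied metadata, then attach the redirect routing token
    # the way both the removed and the re-added code do.
    current = list(metadata) if metadata else []
    if routing_token:
        current.append(("x-goog-request-params", f"routing_token={routing_token}"))
    return current if current else None


print(build_open_metadata(None, "tok-123"))
# [('x-goog-request-params', 'routing_token=tok-123')]
print(build_open_metadata(None, None))
# None
```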
```diff
@@ -460,9 +382,93 @@ async def append(
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before append().")
         if not data:
-            return  # Do nothing for empty data
+            return
+
+        if retry_policy is None:
+            retry_policy = AsyncRetry(predicate=_is_write_retryable)
+
+        buffer = io.BytesIO(data)
+        target_persisted_size = self.persisted_size + len(data)
+        attempt_count = 0
+
+        print("In append method")
+
+        def send_and_recv_generator(requests: List[BidiWriteObjectRequest], state: dict[str, _WriteState], metadata: Optional[List[Tuple[str, str]]] = None):
+            async def generator():
+                print("In send_and_recv_generator")
+                nonlocal attempt_count
+                attempt_count += 1
+                resp = None
+                async with self._lock:
+                    write_state = state["write_state"]
+                    # If this is a retry or redirect, we must re-open the stream
+                    if attempt_count > 1 or write_state.routing_token:
+                        print("Re-opening the stream inside send_and_recv_generator with attempt_count:", attempt_count)
+                        if self.write_obj_stream and self.write_obj_stream.is_stream_open:
+                            await self.write_obj_stream.close()
+
+                        self.write_obj_stream = self._stream_opener(write_handle=write_state.write_handle)
+                        current_metadata = list(metadata) if metadata else []
+                        if write_state.routing_token:
+                            current_metadata.append(("x-goog-request-params", f"routing_token={write_state.routing_token}"))
+                        await self.write_obj_stream.open(metadata=current_metadata if current_metadata else None)
+
+                        self._is_stream_open = True
+                        write_state.persisted_size = self.persisted_size
+                        write_state.write_handle = self.write_handle
+
+                print("Sending requests in send_and_recv_generator")
+                # req_iter = iter(requests)
+
+                print("Starting to send requests")
+                for i, chunk_req in enumerate(requests):
+                    if i == len(requests) - 1:
+                        chunk_req.state_lookup = True
+                    print("Sending chunk request")
+                    await self.write_obj_stream.send(chunk_req)
+                    print("Waiting to receive response")
+                    print("Current persisted_size:", state["write_state"].persisted_size, "Target persisted_size:", target_persisted_size)
+
+                    resp = await self.write_obj_stream.recv()
+                    if resp:
+                        if resp.persisted_size is not None:
+                            self.persisted_size = resp.persisted_size
+                            state["write_state"].persisted_size = resp.persisted_size
+                        if resp.write_handle:
+                            self.write_handle = resp.write_handle
+                            state["write_state"].write_handle = resp.write_handle
+                    print("Received response in send_and_recv_generator", resp)
+
+                    yield resp
+
+                # while state["write_state"].persisted_size < target_persisted_size:
+                #     print("Waiting to receive response")
+                #     print("Current persisted_size:", state["write_state"].persisted_size, "Target persisted_size:", target_persisted_size)
+                #     resp = await self.write_obj_stream.recv()
+                #     print("Received response in send_and_recv_generator", resp)
+                #     if resp is None:
+                #         break
+                #     yield resp
+            return generator()
+
+        # State initialization
+        spec = _storage_v2.AppendObjectSpec(
+            bucket=f"projects/_/buckets/{self.bucket_name}", object=self.object_name, generation=self.generation
+        )
+        write_state = _WriteState(spec, _MAX_CHUNK_SIZE_BYTES, buffer)
+        write_state.write_handle = self.write_handle
+        write_state.persisted_size = self.persisted_size
+        write_state.bytes_sent = self.persisted_size
+
+        print("Before creating retry manager")
+        retry_manager = _BidiStreamRetryManager(_WriteResumptionStrategy(),
+                                                lambda r, s: send_and_recv_generator(r, s, metadata))
+        await retry_manager.execute({"write_state": write_state}, retry_policy)
+
+        # Sync local markers
+        self.write_obj_stream.persisted_size = write_state.persisted_size
+        self.write_obj_stream.write_handle = write_state.write_handle
 
-        await self._upload_with_retry(data, retry_policy, metadata)
 
     async def simple_flush(self) -> None:
         """Flushes the data to the server.
```
```diff
@@ -476,11 +482,12 @@ async def simple_flush(self) -> None:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before simple_flush().")
 
-        await self.write_obj_stream.send(
-            _storage_v2.BidiWriteObjectRequest(
-                flush=True,
+        async with self._lock:
+            await self.write_obj_stream.send(
+                _storage_v2.BidiWriteObjectRequest(
+                    flush=True,
+                )
             )
-        )
 
     async def flush(self) -> int:
         """Flushes the data to the server.
```
```diff
@@ -494,16 +501,17 @@ async def flush(self) -> int:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before flush().")
 
-        await self.write_obj_stream.send(
-            _storage_v2.BidiWriteObjectRequest(
-                flush=True,
-                state_lookup=True,
+        async with self._lock:
+            await self.write_obj_stream.send(
+                _storage_v2.BidiWriteObjectRequest(
+                    flush=True,
+                    state_lookup=True,
+                )
             )
-        )
-        response = await self.write_obj_stream.recv()
-        self.persisted_size = response.persisted_size
-        self.offset = self.persisted_size
-        return self.persisted_size
+            response = await self.write_obj_stream.recv()
+            self.persisted_size = response.persisted_size
+            self.offset = self.persisted_size
+            return self.persisted_size
 
     async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object]:
         """Closes the underlying bidi-gRPC stream.
```
```diff
@@ -553,10 +561,16 @@ async def finalize(self) -> _storage_v2.Object:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before finalize().")
 
+        print("In finalize method")
+
+        # async with self._lock:
+        print("Sending finish_write request")
         await self.write_obj_stream.send(
             _storage_v2.BidiWriteObjectRequest(finish_write=True)
         )
+        print("Waiting to receive response for finalize")
         response = await self.write_obj_stream.recv()
+        print("Received response for finalize:")
         self.object_resource = response.resource
         self.persisted_size = self.object_resource.size
         await self.write_obj_stream.close()
```
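Reviewer note: taken together, the methods touched here give the writer an open → append → flush → finalize lifecycle, and each later call raises `ValueError` unless `open()` ran first. A hypothetical usage sketch over an already-constructed writer; the constructor arguments are not shown in this diff, so construction is elided:

```python
# Hypothetical lifecycle usage, inferred from the docstrings in this diff;
# `writer` is an already-constructed appendable-object writer.
async def append_demo(writer, payload: bytes):
    await writer.open()            # must precede append/flush/finalize
    await writer.append(payload)   # chunked, retried, redirect-aware upload
    size = await writer.flush()    # flush=True + state_lookup=True round trip
    print("persisted so far:", size)
    obj = await writer.finalize()  # finish_write=True; returns the Object resource
    print("finalized size:", obj.size)
```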
