 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
 NOTE:
 This is an experimental module for upcoming support for Rapid Storage.
 (https://cloud.google.com/blog/products/storage-data-transfer/high-performance-storage-innovations-for-ai-hpc#:~:text=your%20AI%20workloads%3A-,Rapid%20Storage,-%3A%20A%20new)
 """
+from __future__ import annotations
+
 from io import BufferedReader, BytesIO
 import asyncio
+import io
 from typing import List, Optional, Tuple, Union

 from google_crc32c import Checksum
 from google.api_core import exceptions
 from google.api_core.retry_async import AsyncRetry
 from google.rpc import status_pb2
 from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError
+from google.cloud._storage_v2.types.storage import BidiWriteObjectRequest


 from ._utils import raise_if_no_fast_crc32c

 def _is_write_retryable(exc):
     """Predicate to determine if a write operation should be retried."""
     if isinstance(
         exc,
         (
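
# A minimal usage sketch (not part of the diff): this predicate is what the
# writer plugs into google.api_core's AsyncRetry by default. `writer` stands
# in for an opened writer instance from this module.
from google.api_core.retry_async import AsyncRetry

async def example_retrying_append(writer, payload: bytes) -> None:
    # Retry the append only on errors _is_write_retryable deems transient.
    policy = AsyncRetry(predicate=_is_write_retryable)
    await writer.append(payload, retry_policy=policy)
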
@@ -192,6 +199,17 @@ def __init__(
         self.bytes_appended_since_last_flush = 0
         self._lock = asyncio.Lock()
         self._routing_token: Optional[str] = None
+        self.object_resource: Optional[_storage_v2.Object] = None
+
+    def _stream_opener(self, write_handle=None):
+        """Helper to create a new _AsyncWriteObjectStream."""
+        return _AsyncWriteObjectStream(
+            client=self.client,
+            bucket_name=self.bucket_name,
+            object_name=self.object_name,
+            generation_number=self.generation,
+            write_handle=write_handle if write_handle else self.write_handle,
+        )

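
# A sketch of the redirect path _stream_opener serves (assumed flow; the
# `redirect` object carrying routing_token and write_handle is illustrative,
# not an API from this module):
async def example_reopen_after_redirect(writer, redirect) -> None:
    # Build a fresh bidi stream that reuses the server-issued write handle.
    writer.write_obj_stream = writer._stream_opener(write_handle=redirect.write_handle)
    # Steer the new stream to the same backend via the routing token.
    metadata = [("x-goog-request-params", f"routing_token={redirect.routing_token}")]
    await writer.write_obj_stream.open(metadata=metadata)
    writer._is_stream_open = True
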
     async def state_lookup(self) -> int:
         """Returns the persisted_size
@@ -205,14 +223,15 @@ async def state_lookup(self) -> int:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before state_lookup().")

-        await self.write_obj_stream.send(
-            _storage_v2.BidiWriteObjectRequest(
-                state_lookup=True,
+        async with self._lock:
+            await self.write_obj_stream.send(
+                _storage_v2.BidiWriteObjectRequest(
+                    state_lookup=True,
+                )
             )
-        )
-        response = await self.write_obj_stream.recv()
-        self.persisted_size = response.persisted_size
-        return self.persisted_size
+            response = await self.write_obj_stream.recv()
+            self.persisted_size = response.persisted_size
+            return self.persisted_size

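
# A usage sketch for state_lookup(): one send/recv round trip that reports how
# much the server has persisted, now serialized under the writer's lock so it
# cannot interleave with a concurrent append() or flush(). `writer` is an
# assumed open writer.
async def example_check_progress(writer) -> int:
    persisted = await writer.state_lookup()
    return persisted
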
     def _on_open_error(self, exc):
         """Extracts routing token and write handle on redirect error during open."""
@@ -288,6 +307,7 @@ def combined_on_error(exc):
         ).with_on_error(combined_on_error)

         async def _do_open():
             current_metadata = list(metadata) if metadata else []

             # Cleanup stream from previous failed attempt, if any.
@@ -314,6 +334,7 @@ async def _do_open():
             )
             self._routing_token = None

             await self.write_obj_stream.open(
                 metadata=current_metadata if current_metadata else None
             )
@@ -327,108 +348,9 @@ async def _do_open():

             self._is_stream_open = True

         await retry_policy(_do_open)()

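# A sketch of opening the writer (the class name and constructor arguments are
# assumptions based on this module): open() drives _do_open through the
# redirect-aware retry policy assembled above.
async def example_open(grpc_client):
    writer = AsyncAppendableObjectWriter(
        client=grpc_client,
        bucket_name="my-bucket",
        object_name="my-object",
    )
    await writer.open()
    return writer
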
-    async def _upload_with_retry(
-        self,
-        data: bytes,
-        retry_policy: Optional[AsyncRetry] = None,
-        metadata: Optional[List[Tuple[str, str]]] = None,
-    ) -> None:
-        if not self._is_stream_open:
-            raise ValueError("Underlying bidi-gRPC stream is not open")
-
-        if retry_policy is None:
-            retry_policy = AsyncRetry(predicate=_is_write_retryable)
-
-        # Initialize Global State for Retry Strategy
-        spec = _storage_v2.AppendObjectSpec(
-            bucket=self.bucket_name,
-            object=self.object_name,
-            generation=self.generation,
-        )
-        buffer = BytesIO(data)
-        write_state = _WriteState(
-            spec=spec,
-            chunk_size=_MAX_CHUNK_SIZE_BYTES,
-            user_buffer=buffer,
-        )
-        write_state.write_handle = self.write_handle
-
-        initial_state = {
-            "write_state": write_state,
-            "first_request": True,
-        }
-
-        # Track attempts to manage stream reuse
-        attempt_count = 0
-
-        def stream_opener(
-            requests,
-            state,
-            metadata: Optional[List[Tuple[str, str]]] = None,
-        ):
-            async def generator():
-                nonlocal attempt_count
-                attempt_count += 1
-
-                async with self._lock:
-                    current_handle = state["write_state"].write_handle
-                    current_token = state["write_state"].routing_token
-
-                    should_reopen = (attempt_count > 1) or (current_token is not None)
-
-                    if should_reopen:
-                        if self.write_obj_stream and self.write_obj_stream._is_stream_open:
-                            await self.write_obj_stream.close()
-
-                        self.write_obj_stream = _AsyncWriteObjectStream(
-                            client=self.client,
-                            bucket_name=self.bucket_name,
-                            object_name=self.object_name,
-                            generation_number=self.generation,
-                            write_handle=current_handle,
-                        )
-
-                        current_metadata = list(metadata) if metadata else []
-                        if current_token:
-                            current_metadata.append(
-                                (
-                                    "x-goog-request-params",
-                                    f"routing_token={current_token}",
-                                )
-                            )
-
-                        await self.write_obj_stream.open(
-                            metadata=current_metadata if current_metadata else None
-                        )
-                        self._is_stream_open = True
-
-                # Let the strategy generate the request sequence
-                async for request in requests:
-                    await self.write_obj_stream.send(request)
-
-                # Signal that we are done sending requests.
-                await self.write_obj_stream.requests.put(None)
-
-                # Process responses
-                async for response in self.write_obj_stream:
-                    yield response
-
-            return generator()
-
-        strategy = _WriteResumptionStrategy()
-        retry_manager = _BidiStreamRetryManager(
-            strategy, lambda r, s: stream_opener(r, s, metadata=metadata)
-        )
-
-        await retry_manager.execute(initial_state, retry_policy)
-
-        # Update the writer's state from the strategy's final state
-        final_write_state = initial_state["write_state"]
-        self.persisted_size = final_write_state.persisted_size
-        self.write_handle = final_write_state.write_handle
-        self.offset = self.persisted_size

     async def append(
         self,
@@ -460,9 +382,93 @@ async def append(
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before append().")
         if not data:
-            return  # Do nothing for empty data
+            return
+
+        if retry_policy is None:
+            retry_policy = AsyncRetry(predicate=_is_write_retryable)
+
+        buffer = io.BytesIO(data)
+        attempt_count = 0
+
+        def send_and_recv_generator(
+            requests: List[BidiWriteObjectRequest],
+            state: dict[str, _WriteState],
+            metadata: Optional[List[Tuple[str, str]]] = None,
+        ):
+            async def generator():
+                nonlocal attempt_count
+                attempt_count += 1
+                async with self._lock:
+                    write_state = state["write_state"]
+                    # If this is a retry or redirect, we must re-open the stream.
+                    if attempt_count > 1 or write_state.routing_token:
+                        if self.write_obj_stream and self.write_obj_stream._is_stream_open:
+                            await self.write_obj_stream.close()
+
+                        self.write_obj_stream = self._stream_opener(
+                            write_handle=write_state.write_handle
+                        )
+                        current_metadata = list(metadata) if metadata else []
+                        if write_state.routing_token:
+                            current_metadata.append(
+                                (
+                                    "x-goog-request-params",
+                                    f"routing_token={write_state.routing_token}",
+                                )
+                            )
+                        await self.write_obj_stream.open(
+                            metadata=current_metadata if current_metadata else None
+                        )
+
+                        self._is_stream_open = True
+                        write_state.persisted_size = self.persisted_size
+                        write_state.write_handle = self.write_handle
+
+                    # Send every chunk, asking for a state_lookup on the last
+                    # one so the server reports the persisted size back.
+                    for i, chunk_req in enumerate(requests):
+                        if i == len(requests) - 1:
+                            chunk_req.state_lookup = True
+                        await self.write_obj_stream.send(chunk_req)
+
+                    resp = await self.write_obj_stream.recv()
+                    if resp:
+                        if resp.persisted_size is not None:
+                            self.persisted_size = resp.persisted_size
+                            write_state.persisted_size = resp.persisted_size
+                        if resp.write_handle:
+                            self.write_handle = resp.write_handle
+                            write_state.write_handle = resp.write_handle
+
+                    yield resp
+
+            return generator()
+
+        # State initialization
+        spec = _storage_v2.AppendObjectSpec(
+            bucket=f"projects/_/buckets/{self.bucket_name}",
+            object=self.object_name,
+            generation=self.generation,
+        )
+        write_state = _WriteState(spec, _MAX_CHUNK_SIZE_BYTES, buffer)
+        write_state.write_handle = self.write_handle
+        write_state.persisted_size = self.persisted_size
+        write_state.bytes_sent = self.persisted_size
+
+        retry_manager = _BidiStreamRetryManager(
+            _WriteResumptionStrategy(),
+            lambda r, s: send_and_recv_generator(r, s, metadata),
+        )
+        await retry_manager.execute({"write_state": write_state}, retry_policy)
+
+        # Sync local markers.
+        self.write_obj_stream.persisted_size = write_state.persisted_size
+        self.write_obj_stream.write_handle = write_state.write_handle

-        await self._upload_with_retry(data, retry_policy, metadata)

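# An end-to-end sketch of the append flow this change rewires (writer creation
# as in the earlier sketch is assumed; the call sequence follows this file):
async def example_append_session(writer) -> None:
    await writer.open()
    await writer.append(b"hello, ")
    await writer.append(b"world")
    persisted = await writer.flush()  # waits for the server-acknowledged size
    assert persisted == writer.offset
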
     async def simple_flush(self) -> None:
         """Flushes the data to the server.
@@ -476,11 +482,12 @@ async def simple_flush(self) -> None:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before simple_flush().")

-        await self.write_obj_stream.send(
-            _storage_v2.BidiWriteObjectRequest(
-                flush=True,
+        async with self._lock:
+            await self.write_obj_stream.send(
+                _storage_v2.BidiWriteObjectRequest(
+                    flush=True,
+                )
             )
-        )

     async def flush(self) -> int:
         """Flushes the data to the server.
@@ -494,16 +501,17 @@ async def flush(self) -> int:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before flush().")

-        await self.write_obj_stream.send(
-            _storage_v2.BidiWriteObjectRequest(
-                flush=True,
-                state_lookup=True,
+        async with self._lock:
+            await self.write_obj_stream.send(
+                _storage_v2.BidiWriteObjectRequest(
+                    flush=True,
+                    state_lookup=True,
+                )
             )
-        )
-        response = await self.write_obj_stream.recv()
-        self.persisted_size = response.persisted_size
-        self.offset = self.persisted_size
-        return self.persisted_size
+            response = await self.write_obj_stream.recv()
+            self.persisted_size = response.persisted_size
+            self.offset = self.persisted_size
+            return self.persisted_size

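# A sketch contrasting the two flush flavors above: simple_flush() only sends
# flush=True and returns, while flush() also sets state_lookup=True and blocks
# on the response to refresh persisted_size. `writer` is an assumed open writer.
async def example_flush_flavors(writer) -> None:
    await writer.simple_flush()       # enqueue a flush; no acknowledgement
    persisted = await writer.flush()  # flush + state_lookup round trip
    assert persisted == writer.persisted_size
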
     async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object]:
         """Closes the underlying bidi-gRPC stream.
@@ -553,10 +561,16 @@ async def finalize(self) -> _storage_v2.Object:
         if not self._is_stream_open:
             raise ValueError("Stream is not open. Call open() before finalize().")

         await self.write_obj_stream.send(
             _storage_v2.BidiWriteObjectRequest(finish_write=True)
         )
         response = await self.write_obj_stream.recv()
         self.object_resource = response.resource
         self.persisted_size = self.object_resource.size
         await self.write_obj_stream.close()
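
# A finalize sketch: finish_write=True goes out, the response carries the
# finalized resource, and the stream is closed; per the signature above the
# method returns the _storage_v2.Object. `writer` is an assumed open writer.
async def example_finalize(writer):
    obj = await writer.finalize()
    assert writer.persisted_size == obj.size
    return obj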