@@ -711,9 +711,13 @@ class _ReadStorageStreamsSDF(beam.DoFn,
def __init__(
    self,
    batch_arrow_read: bool = True,
    change_timestamp_column: str = 'change_timestamp',
    max_split_rounds: int = 1,
    emit_raw_batches: bool = False) -> None:
  """Configure the stream-reading SDF.

  Args:
    batch_arrow_read: If True, convert Arrow RecordBatches in bulk rather
      than per-cell.
    change_timestamp_column: Name of the column holding each row's event
      timestamp.
    max_split_rounds: Rounds of SplitReadStream fan-out applied in
      initial_restriction; 0 disables splitting.
    emit_raw_batches: If True, emit raw (schema_bytes, batch_bytes) tuples
      instead of decoded row dicts.
  """
  self._batch_arrow_read = batch_arrow_read
  self._change_timestamp_column = change_timestamp_column
  self._max_split_rounds = max_split_rounds
  self._emit_raw_batches = emit_raw_batches
  # Created lazily by _ensure_client(); not pickled with the DoFn.
  self._storage_client = None
719723 def _ensure_client (self ) -> None :
@@ -730,16 +734,80 @@ def _ensure_client(self) -> None:
def setup(self) -> None:
  """DoFn lifecycle hook: make sure the Storage API client exists."""
  self._ensure_client()
def _split_all_streams(
    self, stream_names: Tuple[str, ...],
    max_split_rounds: int) -> Tuple[str, ...]:
  """Split every stream at fraction=0.5, repeating up to max_split_rounds.

  Each round walks the current stream list and asks the server to split
  every stream not already known to be unsplittable. A successful split
  swaps the stream for its primary + remainder children; a refused split
  (server returns empty names for both halves) keeps the stream and marks
  it so later rounds skip it. Iteration stops after max_split_rounds, or
  earlier if a whole round yields no new splits.

  How far this can go is decided server-side: small tables may refuse
  every split, while large tables may allow several doubling rounds.
  """
  current = list(stream_names)
  unsplittable = set()
  for attempt in range(1, max_split_rounds + 1):
    next_streams = []
    split_happened = False
    for stream in current:
      if stream in unsplittable:
        next_streams.append(stream)
        continue
      response = self._storage_client.split_read_stream(
          request=bq_storage.types.SplitReadStreamRequest(
              name=stream, fraction=0.5))
      primary = response.primary_stream.name
      remainder = response.remainder_stream.name
      if primary and remainder:
        next_streams.append(primary)
        next_streams.append(remainder)
        split_happened = True
      else:
        # Server refused: keep the original and never retry it.
        next_streams.append(stream)
        unsplittable.add(stream)
    current = next_streams
    _LOGGER.info(
        '[Read] _split_all_streams round %d/%d: %d streams '
        '(progress=%s)',
        attempt,
        max_split_rounds,
        len(current),
        split_happened)
    if not split_happened:
      break
  return tuple(current)
def initial_restriction(self, element: _QueryResult) -> _StreamRestriction:
  """Create a ReadSession and build a _StreamRestriction over its streams.

  When max_split_rounds > 0, each stream from CreateReadSession is further
  subdivided via SplitReadStream at fraction=0.5 for up to max_split_rounds
  rounds, pushing parallelism beyond what the session alone provides.
  """
  self._ensure_client()
  table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)
  session = self._create_read_session(element.temp_table_ref)
  stream_names = tuple(stream.name for stream in session.streams)
  original_count = len(stream_names)
  _LOGGER.info(
      '[Read] initial_restriction for %s: %d streams from CreateReadSession',
      table_key,
      original_count)

  if self._max_split_rounds > 0:
    stream_names = self._split_all_streams(
        stream_names, self._max_split_rounds)
    _LOGGER.info(
        '[Read] initial_restriction for %s: %d -> %d streams '
        'after SplitReadStream',
        table_key,
        original_count,
        len(stream_names))

  return _StreamRestriction(stream_names, 0, len(stream_names))
744812
745813 def create_tracker (
@@ -767,8 +835,7 @@ def process(
767835 element : _QueryResult ,
768836 restriction_tracker = beam .DoFn .RestrictionParam (),
769837 watermark_estimator = beam .DoFn .WatermarkEstimatorParam (
770- _CDCWatermarkEstimatorProvider ())
771- ) -> Iterable [Dict [str , Any ]]:
838+ _CDCWatermarkEstimatorProvider ())):
772839 self ._ensure_client ()
773840 table_key = bigquery_tools .get_hashable_destination (element .temp_table_ref )
774841
@@ -785,7 +852,6 @@ def process(
785852 total_streams = len (stream_names )
786853
787854 streams_read = 0
788- total_rows = 0
789855
790856 _LOGGER .info (
791857 '[Read] Reading streams [%d, %d) of %d total for %s' ,
@@ -808,19 +874,27 @@ def process(
808874 '[Read] try_claim(%d) succeeded: reading stream %s' , i , stream_name )
809875
810876 stream_rows = 0
811- for row in self ._read_stream (stream_name ):
812- ts = row .get (self ._change_timestamp_column )
813- if ts is None :
814- raise ValueError (
815- 'Row missing %r column. Row keys: %s' %
816- (self ._change_timestamp_column , list (row .keys ())))
817- if isinstance (ts , datetime .datetime ):
818- ts = Timestamp .from_utc_datetime (ts )
819-
820- yield TimestampedValue (row , ts )
821- stream_rows += 1
822- total_rows += 1
823- Metrics .counter ('BigQueryChangeHistory' , 'rows_emitted' ).inc (total_rows )
877+ if self ._emit_raw_batches :
878+ stream_batches = 0
879+ for raw_batch in self ._read_stream_raw (stream_name ):
880+ yield TimestampedValue (raw_batch , element .range_start )
881+ stream_batches += 1
882+ Metrics .counter ('BigQueryChangeHistory' ,
883+ 'batches_emitted' ).inc (stream_batches )
884+ else :
885+ for row in self ._read_stream (stream_name ):
886+ ts = row .get (self ._change_timestamp_column )
887+ if ts is None :
888+ raise ValueError (
889+ 'Row missing %r column. Row keys: %s' %
890+ (self ._change_timestamp_column , list (row .keys ())))
891+ if isinstance (ts , datetime .datetime ):
892+ ts = Timestamp .from_utc_datetime (ts )
893+
894+ yield TimestampedValue (row , ts )
895+ stream_rows += 1
896+ Metrics .counter ('BigQueryChangeHistory' ,
897+ 'rows_emitted' ).inc (stream_rows )
824898
825899 streams_read += 1
826900 _LOGGER .info (
@@ -838,16 +912,19 @@ def process(
838912 _utc (element .range_end ),
839913 table_key )
840914
915+ # Release the storage client so the gRPC channel doesn't go stale
916+ # between process() calls. _ensure_client() will create a fresh one.
917+ self ._storage_client = None
918+
841919 # Emit cleanup signal. Every split that reads at least one stream
842920 # reports how many it read.
843921 if streams_read > 0 :
844922 _LOGGER .info (
845923 '[Read] Emitting cleanup signal for %s: '
846- 'streams_read=%d, total_streams=%d, total_rows=%d ' ,
924+ 'streams_read=%d, total_streams=%d' ,
847925 table_key ,
848926 streams_read ,
849- total_streams ,
850- total_rows )
927+ total_streams )
851928 yield beam .pvalue .TaggedOutput (
852929 _CLEANUP_TAG , (table_key , (streams_read , total_streams )))
853930
@@ -863,7 +940,7 @@ def _create_read_session(self, table_ref: 'bigquery.TableReference') -> Any:
863940 requested_session .data_format = bq_storage .types .DataFormat .ARROW
864941 read_options = requested_session .read_options
865942 read_options .arrow_serialization_options .buffer_compression = (
866- bq_storage .types .ArrowSerializationOptions .CompressionCodec .LZ4_FRAME )
943+ bq_storage .types .ArrowSerializationOptions .CompressionCodec .ZSTD )
867944
868945 session = self ._storage_client .create_read_session (
869946 parent = f'projects/{ table_ref .projectId } ' ,
@@ -879,7 +956,7 @@ def _read_stream(self, stream_name: str) -> Iterable[Dict[str, Any]]:
879956 """Read all rows from a single Storage API stream as dicts.
880957
881958 When batch_arrow_read is enabled, converts entire Arrow RecordBatches
882- at once using to_pydict () instead of calling .as_py() on each cell
959+ at once using to_pylist () instead of calling .as_py() on each cell
883960 individually. This is ~1.5x faster for large tables at the cost of ~2x
884961 peak memory per batch.
885962 """
@@ -925,6 +1002,56 @@ def _read_stream_batch(self, stream_name: str) -> Iterable[Dict[str, Any]]:
9251002 elapsed ,
9261003 row_count / elapsed if elapsed > 0 else 0 )
9271004
def _read_stream_raw(self, stream_name: str) -> Iterable[Tuple[bytes, bytes]]:
  """Yield raw (schema_bytes, batch_bytes) without decompression.

  Used when emit_raw_batches is enabled: decompression and Arrow-to-Python
  conversion are deferred to a downstream DoFn after reshuffling. The
  serialized schema (captured from the first response that carries one) is
  paired with every batch so each tuple can be decoded on its own.
  """
  schema_bytes = b''
  emitted = 0
  started = time.time()
  for response in self._storage_client.read_rows(stream_name):
    serialized_schema = response.arrow_schema.serialized_schema
    if serialized_schema and not schema_bytes:
      schema_bytes = bytes(serialized_schema)
    payload = response.arrow_record_batch.serialized_record_batch
    if schema_bytes and payload:
      yield (schema_bytes, bytes(payload))
      emitted += 1
  elapsed = time.time() - started
  _LOGGER.info('[Read] raw_read: %d batches in %.2fs', emitted, elapsed)
1026+
class _DecompressArrowBatchesFn(beam.DoFn):
  """Decompress and convert raw Arrow batches to timestamped row dicts.

  Receives individual (schema_bytes, batch_bytes) tuples after Reshuffle
  and converts each batch to individual row dicts with event timestamps
  extracted from the change_timestamp column.
  """
  def __init__(self, change_timestamp_column: str = 'change_timestamp') -> None:
    # Column whose value becomes each element's event timestamp.
    self._change_timestamp_column = change_timestamp_column

  def process(
      self, element: Tuple[bytes, bytes]) -> Iterable[TimestampedValue]:
    # NOTE: annotation fixed — this yields TimestampedValue wrappers, not
    # bare dicts (matching the annotation removed from the read SDF's
    # process() for the same reason).
    schema_bytes, batch_bytes = element
    # Every element carries its own schema, so each batch decodes
    # independently of ordering after the shuffle.
    schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
    batch = pyarrow.ipc.read_record_batch(
        pyarrow.py_buffer(batch_bytes), schema)

    rows = batch.to_pylist()
    for row in rows:
      ts = row.get(self._change_timestamp_column)
      if ts is None:
        raise ValueError(
            'Row missing %r column. Row keys: %s' %
            (self._change_timestamp_column, list(row.keys())))
      if isinstance(ts, datetime.datetime):
        ts = Timestamp.from_utc_datetime(ts)
      yield TimestampedValue(row, ts)
    Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows))
9281055
9291056# =============================================================================
9301057# Cleanup: _CleanupTempTablesFn
@@ -1038,9 +1165,21 @@ class ReadBigQueryChangeHistory(beam.PTransform):
10381165 on the CHANGES/APPENDS query. Do not include the WHERE keyword.
10391166 Example: ``'status = "active" AND region = "US"'``.
10401167 batch_arrow_read: If True (default), convert Arrow RecordBatches in
1041- bulk using to_pydict () instead of per-cell .as_py() calls.
1168+ bulk using to_pylist () instead of per-cell .as_py() calls.
10421169 This is 1.5x faster for large tables at the cost of ~2x peak
10431170 memory per RecordBatch. Set to False for minimal memory usage.
1171+ max_split_rounds: Maximum number of recursive SplitReadStream
1172+ rounds. Each round splits every stream at fraction=0.5,
1173+ potentially doubling the stream count (if BQ allows). Default
1174+ 1 (one round of splitting). Set 0 to disable splitting
1175+ entirely. Set higher for very large tables where more
1176+ parallelism is needed.
1177+ reshuffle_decompress: If True (default), the Read SDF emits raw
1178+ compressed Arrow batches instead of decoded rows. The batches
1179+ are reshuffled for fan-out and then decoded in a separate DoFn.
1180+ This spreads decompression and Arrow-to-Python conversion CPU
1181+ across more workers. Set to False to decode rows inline within
1182+ the Read SDF.
10441183 """
10451184 def __init__ (
10461185 self ,
@@ -1057,7 +1196,9 @@ def __init__(
10571196 change_timestamp_column : str = 'change_timestamp' ,
10581197 columns : Optional [List [str ]] = None ,
10591198 row_filter : Optional [str ] = None ,
1060- batch_arrow_read : bool = True ) -> None :
1199+ batch_arrow_read : bool = True ,
1200+ max_split_rounds : int = 1 ,
1201+ reshuffle_decompress : bool = True ) -> None :
10611202 super ().__init__ ()
10621203 if bq_storage is None :
10631204 raise ImportError (
@@ -1091,6 +1232,8 @@ def __init__(
10911232 self ._columns = columns
10921233 self ._row_filter = row_filter
10931234 self ._batch_arrow_read = batch_arrow_read
1235+ self ._max_split_rounds = max_split_rounds
1236+ self ._reshuffle_decompress = reshuffle_decompress
10941237
10951238 def expand (self , pbegin : beam .pvalue .PBegin ) -> beam .PCollection :
10961239 project = self ._project
@@ -1170,16 +1313,37 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
11701313 row_filter = self ._row_filter ))
11711314 | 'CommitQueryResults' >> beam .Reshuffle ())
11721315
1316+ emit_raw = self ._reshuffle_decompress
1317+
1318+ read_sdf = beam .ParDo (
1319+ _ReadStorageStreamsSDF (
1320+ batch_arrow_read = self ._batch_arrow_read ,
1321+ change_timestamp_column = self ._change_timestamp_column ,
1322+ max_split_rounds = self ._max_split_rounds ,
1323+ emit_raw_batches = emit_raw ))
1324+ if emit_raw :
1325+ read_sdf = read_sdf .with_output_types (Tuple [bytes , bytes ])
1326+ else :
1327+ read_sdf = read_sdf .with_output_types (Dict [str , Any ])
1328+
11731329 read_outputs = (
11741330 query_results
1175- | 'ReadStorageStreams' >> beam .ParDo (
1176- _ReadStorageStreamsSDF (
1177- batch_arrow_read = self ._batch_arrow_read ,
1178- change_timestamp_column = self ._change_timestamp_column )).
1179- with_outputs (_CLEANUP_TAG , main = 'rows' ))
1331+ | 'ReadStorageStreams' >> read_sdf .with_outputs (
1332+ _CLEANUP_TAG , main = 'rows' ))
11801333
11811334 _ = (
11821335 read_outputs [_CLEANUP_TAG ]
11831336 | 'CleanupTempTables' >> beam .ParDo (_CleanupTempTablesFn ()))
11841337
1185- return read_outputs ['rows' ]
1338+ if emit_raw :
1339+ # Reshuffle raw Arrow batches for fan-out, then decompress and
1340+ # convert to timestamped row dicts in a separate DoFn.
1341+ rows = (
1342+ read_outputs ['rows' ]
1343+ | 'ReshuffleForFanout' >> beam .Reshuffle ()
1344+ | 'DecompressBatches' >> beam .ParDo (
1345+ _DecompressArrowBatchesFn (
1346+ change_timestamp_column = (self ._change_timestamp_column ))))
1347+ return rows
1348+ else :
1349+ return read_outputs ['rows' ]
0 commit comments