Skip to content

Commit b39a95f

Browse files
authored
Improve bq cdc scalability (apache#38015)
* Recursively split read streams as far as possible
* Improve scalability
* Add comments
1 parent ee02db8 commit b39a95f

File tree

2 files changed

+281
-52
lines changed

2 files changed

+281
-52
lines changed

sdks/python/apache_beam/io/gcp/bigquery_change_history.py

Lines changed: 197 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -711,9 +711,13 @@ class _ReadStorageStreamsSDF(beam.DoFn,
711711
def __init__(
    self,
    batch_arrow_read: bool = True,
    change_timestamp_column: str = 'change_timestamp',
    max_split_rounds: int = 1,
    emit_raw_batches: bool = False) -> None:
  """Configure the stream-reading splittable DoFn.

  Args:
    batch_arrow_read: When True, convert Arrow RecordBatches in bulk
      rather than cell by cell.
    change_timestamp_column: Name of the column carrying each row's
      change timestamp, used as the event timestamp on emitted rows.
    max_split_rounds: How many rounds of SplitReadStream to apply to the
      streams returned by CreateReadSession; 0 disables splitting.
    emit_raw_batches: When True, emit raw (schema_bytes, batch_bytes)
      tuples instead of decoded row dicts.
  """
  self._batch_arrow_read = batch_arrow_read
  self._change_timestamp_column = change_timestamp_column
  self._max_split_rounds = max_split_rounds
  self._emit_raw_batches = emit_raw_batches
  # Created lazily by _ensure_client(); intentionally not built here so
  # the DoFn stays picklable before setup().
  self._storage_client = None
718722

719723
def _ensure_client(self) -> None:
@@ -730,16 +734,80 @@ def _ensure_client(self) -> None:
730734
def setup(self) -> None:
731735
self._ensure_client()
732736

737+
def _split_all_streams(
    self, stream_names: Tuple[str, ...],
    max_split_rounds: int) -> Tuple[str, ...]:
  """Split each stream at fraction=0.5 for up to max_split_rounds rounds.

  Each round attempts to split every stream in the current list. A
  successful split replaces the original stream with primary + remainder.
  A refused split (both fields empty) keeps the original stream intact
  and marks it so it is never retried. Stops when max_split_rounds is
  reached or a full round produces zero new splits.

  BQ's server-side granularity controls how many splits are possible:
  small tables may not split at all, while large tables may allow
  multiple rounds of doubling.
  """
  current = list(stream_names)
  refused = set()  # streams the server already declined to split
  for round_num in range(1, max_split_rounds + 1):
    survivors = []
    made_progress = False
    for stream in current:
      if stream in refused:
        survivors.append(stream)
        continue
      response = self._storage_client.split_read_stream(
          request=bq_storage.types.SplitReadStreamRequest(
              name=stream, fraction=0.5))
      primary = response.primary_stream.name
      remainder = response.remainder_stream.name
      if primary and remainder:
        # Successful split: both halves replace the parent stream.
        survivors.extend([primary, remainder])
        made_progress = True
      else:
        # Refused split: keep the parent and never ask again.
        survivors.append(stream)
        refused.add(stream)
    current = survivors
    _LOGGER.info(
        '[Read] _split_all_streams round %d/%d: %d streams '
        '(progress=%s)',
        round_num,
        max_split_rounds,
        len(current),
        made_progress)
    if not made_progress:
      break
  return tuple(current)
783+
733784
def initial_restriction(self, element: _QueryResult) -> _StreamRestriction:
734-
"""Create ReadSession and return _StreamRestriction with stream names."""
785+
"""Create ReadSession and return _StreamRestriction with stream names.
786+
787+
When max_split_rounds > 0, uses SplitReadStream to subdivide each
788+
stream at fraction=0.5 for up to max_split_rounds rounds, maximizing
789+
parallelism beyond what CreateReadSession provides.
790+
"""
735791
self._ensure_client()
736792
table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)
737793
session = self._create_read_session(element.temp_table_ref)
738794
stream_names = tuple(s.name for s in session.streams)
795+
original_count = len(stream_names)
739796
_LOGGER.info(
740-
'[Read] initial_restriction for %s: %d streams',
797+
'[Read] initial_restriction for %s: %d streams from CreateReadSession',
741798
table_key,
742-
len(stream_names))
799+
original_count)
800+
801+
if self._max_split_rounds > 0:
802+
stream_names = self._split_all_streams(
803+
stream_names, self._max_split_rounds)
804+
_LOGGER.info(
805+
'[Read] initial_restriction for %s: %d -> %d streams '
806+
'after SplitReadStream',
807+
table_key,
808+
original_count,
809+
len(stream_names))
810+
743811
return _StreamRestriction(stream_names, 0, len(stream_names))
744812

745813
def create_tracker(
@@ -767,8 +835,7 @@ def process(
767835
element: _QueryResult,
768836
restriction_tracker=beam.DoFn.RestrictionParam(),
769837
watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
770-
_CDCWatermarkEstimatorProvider())
771-
) -> Iterable[Dict[str, Any]]:
838+
_CDCWatermarkEstimatorProvider())):
772839
self._ensure_client()
773840
table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)
774841

@@ -785,7 +852,6 @@ def process(
785852
total_streams = len(stream_names)
786853

787854
streams_read = 0
788-
total_rows = 0
789855

790856
_LOGGER.info(
791857
'[Read] Reading streams [%d, %d) of %d total for %s',
@@ -808,19 +874,27 @@ def process(
808874
'[Read] try_claim(%d) succeeded: reading stream %s', i, stream_name)
809875

810876
stream_rows = 0
811-
for row in self._read_stream(stream_name):
812-
ts = row.get(self._change_timestamp_column)
813-
if ts is None:
814-
raise ValueError(
815-
'Row missing %r column. Row keys: %s' %
816-
(self._change_timestamp_column, list(row.keys())))
817-
if isinstance(ts, datetime.datetime):
818-
ts = Timestamp.from_utc_datetime(ts)
819-
820-
yield TimestampedValue(row, ts)
821-
stream_rows += 1
822-
total_rows += 1
823-
Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(total_rows)
877+
if self._emit_raw_batches:
878+
stream_batches = 0
879+
for raw_batch in self._read_stream_raw(stream_name):
880+
yield TimestampedValue(raw_batch, element.range_start)
881+
stream_batches += 1
882+
Metrics.counter('BigQueryChangeHistory',
883+
'batches_emitted').inc(stream_batches)
884+
else:
885+
for row in self._read_stream(stream_name):
886+
ts = row.get(self._change_timestamp_column)
887+
if ts is None:
888+
raise ValueError(
889+
'Row missing %r column. Row keys: %s' %
890+
(self._change_timestamp_column, list(row.keys())))
891+
if isinstance(ts, datetime.datetime):
892+
ts = Timestamp.from_utc_datetime(ts)
893+
894+
yield TimestampedValue(row, ts)
895+
stream_rows += 1
896+
Metrics.counter('BigQueryChangeHistory',
897+
'rows_emitted').inc(stream_rows)
824898

825899
streams_read += 1
826900
_LOGGER.info(
@@ -838,16 +912,19 @@ def process(
838912
_utc(element.range_end),
839913
table_key)
840914

915+
# Release the storage client so the gRPC channel doesn't go stale
916+
# between process() calls. _ensure_client() will create a fresh one.
917+
self._storage_client = None
918+
841919
# Emit cleanup signal. Every split that reads at least one stream
842920
# reports how many it read.
843921
if streams_read > 0:
844922
_LOGGER.info(
845923
'[Read] Emitting cleanup signal for %s: '
846-
'streams_read=%d, total_streams=%d, total_rows=%d',
924+
'streams_read=%d, total_streams=%d',
847925
table_key,
848926
streams_read,
849-
total_streams,
850-
total_rows)
927+
total_streams)
851928
yield beam.pvalue.TaggedOutput(
852929
_CLEANUP_TAG, (table_key, (streams_read, total_streams)))
853930

@@ -863,7 +940,7 @@ def _create_read_session(self, table_ref: 'bigquery.TableReference') -> Any:
863940
requested_session.data_format = bq_storage.types.DataFormat.ARROW
864941
read_options = requested_session.read_options
865942
read_options.arrow_serialization_options.buffer_compression = (
866-
bq_storage.types.ArrowSerializationOptions.CompressionCodec.LZ4_FRAME)
943+
bq_storage.types.ArrowSerializationOptions.CompressionCodec.ZSTD)
867944

868945
session = self._storage_client.create_read_session(
869946
parent=f'projects/{table_ref.projectId}',
@@ -879,7 +956,7 @@ def _read_stream(self, stream_name: str) -> Iterable[Dict[str, Any]]:
879956
"""Read all rows from a single Storage API stream as dicts.
880957
881958
When batch_arrow_read is enabled, converts entire Arrow RecordBatches
882-
at once using to_pydict() instead of calling .as_py() on each cell
959+
at once using to_pylist() instead of calling .as_py() on each cell
883960
individually. This is ~1.5x faster for large tables at the cost of ~2x
884961
peak memory per batch.
885962
"""
@@ -925,6 +1002,56 @@ def _read_stream_batch(self, stream_name: str) -> Iterable[Dict[str, Any]]:
9251002
elapsed,
9261003
row_count / elapsed if elapsed > 0 else 0)
9271004

1005+
def _read_stream_raw(self, stream_name: str) -> Iterable[Tuple[bytes, bytes]]:
  """Yield raw (schema_bytes, batch_bytes) without decompression.

  Used when emit_raw_batches is enabled to defer decompression and
  Arrow-to-Python conversion to a downstream DoFn after reshuffling.
  Schema bytes are included in each tuple so each batch is
  self-contained and can be decoded independently.
  """
  schema_bytes = b''
  emitted = 0
  started = time.time()
  for response in self._storage_client.read_rows(stream_name):
    # Capture the serialized schema from the first response that carries
    # one; every emitted tuple reuses these same bytes.
    if not schema_bytes and response.arrow_schema.serialized_schema:
      schema_bytes = bytes(response.arrow_schema.serialized_schema)
    batch_bytes = response.arrow_record_batch.serialized_record_batch
    # NOTE(review): a batch arriving before any schema is silently
    # skipped — presumably the schema always comes first; confirm.
    if batch_bytes and schema_bytes:
      yield (schema_bytes, bytes(batch_bytes))
      emitted += 1
  elapsed = time.time() - started
  _LOGGER.info('[Read] raw_read: %d batches in %.2fs', emitted, elapsed)
1025+
1026+
1027+
class _DecompressArrowBatchesFn(beam.DoFn):
  """Decompress and convert raw Arrow batches to timestamped row dicts.

  Receives individual (schema_bytes, batch_bytes) tuples after Reshuffle
  and converts each batch to individual row dicts with event timestamps
  extracted from the change_timestamp column.
  """
  def __init__(self, change_timestamp_column: str = 'change_timestamp') -> None:
    # Column whose value becomes the event timestamp of each row.
    self._change_timestamp_column = change_timestamp_column

  def process(self, element: Tuple[bytes, bytes]) -> Iterable[Dict[str, Any]]:
    schema_bytes, batch_bytes = element
    # Each element carries its own schema, so every batch decodes
    # independently of ordering after the shuffle.
    schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
    batch = pyarrow.ipc.read_record_batch(
        pyarrow.py_buffer(batch_bytes), schema)

    decoded_rows = batch.to_pylist()
    for row in decoded_rows:
      ts = row.get(self._change_timestamp_column)
      if ts is None:
        raise ValueError(
            'Row missing %r column. Row keys: %s' %
            (self._change_timestamp_column, list(row.keys())))
      if isinstance(ts, datetime.datetime):
        ts = Timestamp.from_utc_datetime(ts)
      yield TimestampedValue(row, ts)
    # Counter bumped once per batch; not incremented if a missing-column
    # error aborts the loop above.
    Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(
        len(decoded_rows))
1054+
9281055

9291056
# =============================================================================
9301057
# Cleanup: _CleanupTempTablesFn
@@ -1038,9 +1165,21 @@ class ReadBigQueryChangeHistory(beam.PTransform):
10381165
on the CHANGES/APPENDS query. Do not include the WHERE keyword.
10391166
Example: ``'status = "active" AND region = "US"'``.
10401167
batch_arrow_read: If True (default), convert Arrow RecordBatches in
1041-
bulk using to_pydict() instead of per-cell .as_py() calls.
1168+
bulk using to_pylist() instead of per-cell .as_py() calls.
10421169
This is 1.5x faster for large tables at the cost of ~2x peak
10431170
memory per RecordBatch. Set to False for minimal memory usage.
1171+
max_split_rounds: Maximum number of recursive SplitReadStream
1172+
rounds. Each round splits every stream at fraction=0.5,
1173+
potentially doubling the stream count (if BQ allows). Default
1174+
1 (one round of splitting). Set 0 to disable splitting
1175+
entirely. Set higher for very large tables where more
1176+
parallelism is needed.
1177+
reshuffle_decompress: If True (default), the Read SDF emits raw
1178+
compressed Arrow batches instead of decoded rows. The batches
1179+
are reshuffled for fan-out and then decoded in a separate DoFn.
1180+
This spreads decompression and Arrow-to-Python conversion CPU
1181+
across more workers. Set to False to decode rows inline within
1182+
the Read SDF.
10441183
"""
10451184
def __init__(
10461185
self,
@@ -1057,7 +1196,9 @@ def __init__(
10571196
change_timestamp_column: str = 'change_timestamp',
10581197
columns: Optional[List[str]] = None,
10591198
row_filter: Optional[str] = None,
1060-
batch_arrow_read: bool = True) -> None:
1199+
batch_arrow_read: bool = True,
1200+
max_split_rounds: int = 1,
1201+
reshuffle_decompress: bool = True) -> None:
10611202
super().__init__()
10621203
if bq_storage is None:
10631204
raise ImportError(
@@ -1091,6 +1232,8 @@ def __init__(
10911232
self._columns = columns
10921233
self._row_filter = row_filter
10931234
self._batch_arrow_read = batch_arrow_read
1235+
self._max_split_rounds = max_split_rounds
1236+
self._reshuffle_decompress = reshuffle_decompress
10941237

10951238
def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
10961239
project = self._project
@@ -1170,16 +1313,37 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
11701313
row_filter=self._row_filter))
11711314
| 'CommitQueryResults' >> beam.Reshuffle())
11721315

1316+
emit_raw = self._reshuffle_decompress
1317+
1318+
read_sdf = beam.ParDo(
1319+
_ReadStorageStreamsSDF(
1320+
batch_arrow_read=self._batch_arrow_read,
1321+
change_timestamp_column=self._change_timestamp_column,
1322+
max_split_rounds=self._max_split_rounds,
1323+
emit_raw_batches=emit_raw))
1324+
if emit_raw:
1325+
read_sdf = read_sdf.with_output_types(Tuple[bytes, bytes])
1326+
else:
1327+
read_sdf = read_sdf.with_output_types(Dict[str, Any])
1328+
11731329
read_outputs = (
11741330
query_results
1175-
| 'ReadStorageStreams' >> beam.ParDo(
1176-
_ReadStorageStreamsSDF(
1177-
batch_arrow_read=self._batch_arrow_read,
1178-
change_timestamp_column=self._change_timestamp_column)).
1179-
with_outputs(_CLEANUP_TAG, main='rows'))
1331+
| 'ReadStorageStreams' >> read_sdf.with_outputs(
1332+
_CLEANUP_TAG, main='rows'))
11801333

11811334
_ = (
11821335
read_outputs[_CLEANUP_TAG]
11831336
| 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn()))
11841337

1185-
return read_outputs['rows']
1338+
if emit_raw:
1339+
# Reshuffle raw Arrow batches for fan-out, then decompress and
1340+
# convert to timestamped row dicts in a separate DoFn.
1341+
rows = (
1342+
read_outputs['rows']
1343+
| 'ReshuffleForFanout' >> beam.Reshuffle()
1344+
| 'DecompressBatches' >> beam.ParDo(
1345+
_DecompressArrowBatchesFn(
1346+
change_timestamp_column=(self._change_timestamp_column))))
1347+
return rows
1348+
else:
1349+
return read_outputs['rows']

0 commit comments

Comments
 (0)