
Commit 45d4cee

docs(samples): Update BigQuery Storage Arrow samples batching logic (#14961)
Bases batching on size rather than row count to avoid exceeding an internal 10 MB limit. Also removes an obsolete assertion in the test.
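The diff below only touches the sample's helpers, so here is a minimal, hedged sketch of how the updated generator might be driven end to end. It assumes make_table returns the created bigquery.Table and create_stream returns an AppendRowsStream (the diff suggests this but does not show either helper in full), and the project and dataset names are placeholders.

    from google.cloud import bigquery
    from google.cloud.bigquery_storage_v1 import BigQueryWriteClient

    import append_rows_with_arrow as sample  # the sample module changed below

    bq_client = bigquery.Client(project="my-project")  # placeholder project
    write_client = BigQueryWriteClient()

    # Assumed return values: the created table and an AppendRowsStream.
    table = sample.make_table("my-project", "my_dataset", bq_client)
    stream = sample.create_stream(write_client, table)

    pyarrow_table = sample.generate_pyarrow_table()

    # Requests are now sized by bytes rather than row count, so each send()
    # stays under the documented AppendRowsRequest size limit.
    futures = [
        stream.send(request)
        for request in sample.generate_write_requests(pyarrow_table)
    ]

    # Wait for every append to be acknowledged.
    for future in futures:
        future.result()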
1 parent 1e3be63 commit 45d4cee

File tree

2 files changed: +153, −20 lines

packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py

Lines changed: 67 additions & 20 deletions
@@ -16,14 +16,14 @@
 import datetime
 import decimal

-from google.cloud import bigquery
 from google.cloud.bigquery import enums
-from google.cloud.bigquery_storage_v1 import types as gapic_types
-from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
 import pandas as pd
-
 import pyarrow as pa

+from google.cloud import bigquery
+from google.cloud.bigquery_storage_v1 import types as gapic_types
+from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
+
 TABLE_LENGTH = 100_000

 BQ_SCHEMA = [
@@ -100,7 +100,10 @@ def make_table(project_id, dataset_id, bq_client):


 def create_stream(bqstorage_write_client, table):
-    stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default"
+    stream_name = (
+        f"projects/{table.project}/datasets/{table.dataset_id}/"
+        f"tables/{table.table_id}/_default"
+    )
     request_template = gapic_types.AppendRowsRequest()
     request_template.write_stream = stream_name

@@ -160,18 +163,63 @@ def generate_pyarrow_table(num_rows=TABLE_LENGTH):


 def generate_write_requests(pyarrow_table):
-    # Determine max_chunksize of the record batches. Because max size of
-    # AppendRowsRequest is 10 MB, we need to split the table if it's too big.
-    # See: https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1#appendrowsrequest
-    max_request_bytes = 10 * 2**20  # 10 MB
-    chunk_num = int(pyarrow_table.nbytes / max_request_bytes) + 1
-    chunk_size = int(pyarrow_table.num_rows / chunk_num)
-
-    # Construct request(s).
-    for batch in pyarrow_table.to_batches(max_chunksize=chunk_size):
+    # Maximum size for a single AppendRowsRequest is 10 MB.
+    # To be safe, we'll aim for a soft limit of 7 MB.
+    max_request_bytes = 7 * 1024 * 1024  # 7 MB
+
+    def _create_request(batches):
+        """Helper to create an AppendRowsRequest from a list of batches."""
+        combined_table = pa.Table.from_batches(batches)
         request = gapic_types.AppendRowsRequest()
-        request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
-        yield request
+        request.arrow_rows.rows.serialized_record_batch = (
+            combined_table.combine_chunks().to_batches()[0].serialize().to_pybytes()
+        )
+        return request
+
+    batches = pyarrow_table.to_batches()
+
+    current_batches = []
+    current_size = 0
+
+    while batches:
+        batch = batches.pop()
+        batch_size = batch.nbytes
+
+        # If the batch is larger than max_request_bytes, split it into two sub-batches.
+        if batch_size > max_request_bytes:
+            if batch.num_rows > 1:
+                # Split the batch into two sub-batches of roughly equal size.
+                mid = batch.num_rows // 2
+                batch_left = batch.slice(offset=0, length=mid)
+                batch_right = batch.slice(offset=mid)
+
+                # Push the sub-batches back onto the stack and keep popping.
+                batches.append(batch_left)
+                batches.append(batch_right)
+                continue
+            # A single-row batch that still exceeds max_request_bytes cannot be split further.
+            else:
+                raise ValueError(
+                    f"A single PyArrow batch of one row is larger than the maximum request size "
+                    f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
+                )
+        # Emit a request from the accumulated batches once the next batch would exceed max_request_bytes.
+        if current_size + batch_size > max_request_bytes:
+            # current_batches cannot be empty here; an oversized lone batch would have been split above.
+            yield _create_request(current_batches)
+
+            current_batches = []
+            current_size = 0
+            batches.append(batch)

+        # Otherwise, add the batch to current_batches.
+        else:
+            current_batches.append(batch)
+            current_size += batch_size
+
+    # Flush remaining batches.
+    if current_batches:
+        yield _create_request(current_batches)


 def verify_result(client, table, futures):
@@ -181,14 +229,13 @@ def verify_result(client, table, futures):
     assert bq_table.schema == BQ_SCHEMA

     # Verify table size.
-    query = client.query(f"SELECT COUNT(1) FROM `{bq_table}`;")
+    query = client.query(f"SELECT DISTINCT int64_col FROM `{bq_table}`;")
     query_result = query.result().to_dataframe()

-    # There might be extra rows due to retries.
-    assert query_result.iloc[0, 0] >= TABLE_LENGTH
+    assert len(query_result) == TABLE_LENGTH

     # Verify that table was split into multiple requests.
-    assert len(futures) == 2
+    assert len(futures) == 4


 def main(project_id, dataset):
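For context on the splitting step added above: RecordBatch.slice is a zero-copy operation in PyArrow, so halving an oversized batch and pushing the halves back onto the stack is cheap. A tiny, self-contained illustration with made-up data (the sample's real batches use PYARROW_SCHEMA, and the column name here is only an example):

    import pyarrow as pa

    # Made-up single-column batch, just to show the slicing used above.
    batch = pa.RecordBatch.from_pydict({"int64_col": list(range(10))})

    mid = batch.num_rows // 2
    left = batch.slice(offset=0, length=mid)  # rows [0, mid)
    right = batch.slice(offset=mid)           # rows [mid, num_rows)

    assert left.num_rows + right.num_rows == batch.num_rows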
New test file — Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import pyarrow as pa
+import pytest
+
+from . import append_rows_with_arrow
+
+
+def create_table_with_batches(num_batches, rows_per_batch):
+    # Generate a small table to get a valid batch.
+    small_table = append_rows_with_arrow.generate_pyarrow_table(rows_per_batch)
+    # Ensure we get exactly one batch for the small table.
+    batches = small_table.to_batches()
+    assert len(batches) == 1
+    batch = batches[0]
+
+    # Replicate the batch.
+    all_batches = [batch] * num_batches
+    return pa.Table.from_batches(all_batches)
+
+
+# Test generate_write_requests with different numbers of batches in the input table.
+# The total number of rows in the generated table is always 1,000,000.
+@pytest.mark.parametrize(
+    "num_batches, rows_per_batch, expected_requests",
+    [
+        (1, 1000000, 32),
+        (10, 100000, 40),
+        (100, 10000, 34),
+        (1000, 1000, 26),
+        (10000, 100, 26),
+        (100000, 10, 26),
+        (1000000, 1, 26),
+    ],
+)
+def test_generate_write_requests_varying_batches(
+    num_batches, rows_per_batch, expected_requests
+):
+    """Test generate_write_requests with different numbers of batches in the input table."""
+    # Create a table that returns `num_batches` batches when to_batches() is called.
+    table = create_table_with_batches(num_batches, rows_per_batch)
+
+    # Verify our setup is correct.
+    assert len(table.to_batches()) == num_batches
+
+    # Generate requests.
+    start_time = time.perf_counter()
+    requests = list(append_rows_with_arrow.generate_write_requests(table))
+    end_time = time.perf_counter()
+    print(
+        f"\nTime used to generate requests for {num_batches} batches: {end_time - start_time:.4f} seconds"
+    )
+
+    assert len(requests) == expected_requests
+
+    # Verify that the total rows in the requests match the total rows in the table.
+    total_rows_processed = 0
+    for request in requests:
+        # Deserialize the batch from the request to count rows.
+        serialized_batch = request.arrow_rows.rows.serialized_record_batch
+
+        # Verify the batch size is at most 7 MB.
+        assert len(serialized_batch) <= 7 * 1024 * 1024
+
+        # We need a schema to read the batch. The schema is PYARROW_SCHEMA.
+        batch = pa.ipc.read_record_batch(
+            serialized_batch, append_rows_with_arrow.PYARROW_SCHEMA
+        )
+        total_rows_processed += batch.num_rows
+
+    expected_rows = num_batches * rows_per_batch
+    assert total_rows_processed == expected_rows
