Commit 610f38e

Fix: Improve PyArrow batching and serialization in BigQuery Storage sample
- Updates batching logic to use serialized size to avoid exceeding API limits.
- Ensures all rows in the PyArrow table are serialized for the request.
- Includes enhancements for measuring serialized row sizes.
1 parent f581b33 commit 610f38e
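The commit message refers to serialized sizes, while the updated sample (shown in the diff below) compares batches against its 7 MB soft limit using nbytes. As a point of reference only, the following sketch is a hypothetical illustration, not part of this commit, of how a PyArrow RecordBatch's in-memory size (nbytes) can be compared with its Arrow IPC stream-serialized size; the helper name serialized_batch_size and the toy data are assumptions for illustration.

# Hypothetical helper for illustration; not part of append_rows_with_arrow.py.
import pyarrow as pa


def serialized_batch_size(batch: pa.RecordBatch) -> int:
    """Return the number of bytes `batch` occupies in the Arrow IPC stream format."""
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
    return sink.getvalue().size


if __name__ == "__main__":
    batch = pa.RecordBatch.from_pydict(
        {"id": list(range(1_000)), "name": [f"row-{i}" for i in range(1_000)]}
    )
    print("in-memory size (nbytes):", batch.nbytes)
    print("IPC-serialized size:", serialized_batch_size(batch))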

File tree

1 file changed (+63, -54 lines)
packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py

Lines changed: 63 additions & 54 deletions
@@ -15,16 +15,15 @@
 # limitations under the License.
 import datetime
 import decimal
-import time
 
-from google.cloud import bigquery
 from google.cloud.bigquery import enums
-from google.cloud.bigquery_storage_v1 import types as gapic_types
-from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
-
 import pandas as pd
 import pyarrow as pa
 
+from google.cloud import bigquery
+from google.cloud.bigquery_storage_v1 import types as gapic_types
+from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
+
 TABLE_LENGTH = 100_000
 
 BQ_SCHEMA = [
@@ -167,6 +166,7 @@ def generate_write_requests(pyarrow_table):
     # Maximum size for a single AppendRowsRequest is 10 MB.
     # To be safe, we'll aim for a soft limit of 7 MB.
     max_request_bytes = 7 * 1024 * 1024  # 7 MB
+    requests = []
 
     def _create_request(batches):
         """Helper to create an AppendRowsRequest from a list of batches."""
@@ -177,53 +177,61 @@ def _create_request(batches):
         )
         return request
 
-    batches_in_request = []
+    # 1. Use pyarrow_table.to_batches() to get the batches as a stack.
+    batches_as_stack = list(pyarrow_table.to_batches())
+    batches_as_stack.reverse()
+
+    # current_size is initially 0
+    # current_batches is initially an empty list
+    current_batches = []
     current_size = 0
-    total_time = 0
-    request_count = 0
-
-    # Split table into batches with one row.
-    for row_batch in pyarrow_table.to_batches(max_chunksize=1):
-        serialized_batch = row_batch.serialize().to_pybytes()
-        batch_size = len(serialized_batch)
-
-        if batch_size > max_request_bytes:
-            raise ValueError(
-                (
-                    "A single PyArrow batch of one row is larger than the "
-                    f"maximum request size (batch size: {batch_size} > "
-                    f"max request size: {max_request_bytes}). Cannot proceed."
-                )
-            )
-
-        if current_size + batch_size > max_request_bytes and batches_in_request:
-            # Combine collected batches and yield request
-            request_count += 1
-            start_time = time.time()
-            yield _create_request(batches_in_request)
-            end_time = time.time()
-            request_time = end_time - start_time
-            print(f"Time to generate request {request_count}: {request_time:.4f} seconds")
-            total_time += request_time
-
-            # Reset for next request.
-            batches_in_request = []
-            current_size = 0
-
-        batches_in_request.append(row_batch)
-        current_size += batch_size
-
-    # Yield any remaining batches
-    if batches_in_request:
-        request_count += 1
-        start_time = time.time()
-        yield _create_request(batches_in_request)
-        end_time = time.time()
-        request_time = end_time - start_time
-        print(f"Time to generate request {request_count}: {request_time:.4f} seconds")
-        total_time += request_time
-
-    print(f"\nTotal time to generate all {request_count} requests: {total_time:.4f} seconds")
+
+    # 2. Repeat the following until the stack is empty:
+    while batches_as_stack:
+        batch = batches_as_stack.pop()
+        batch_size = batch.nbytes
+
+        if current_size + batch_size > max_request_bytes:
+            if batch.num_rows > 1:
+                # Split the batch into two sub-batches with roughly equal row counts.
+                mid = batch.num_rows // 2
+                batch_left = batch.slice(offset=0, length=mid)
+                batch_right = batch.slice(offset=mid)
+
+                # Push the new batches back onto the stack.
+                batches_as_stack.append(batch_right)
+                batches_as_stack.append(batch_left)
+                # Repeat the popping.
+                continue
+
+            # If the batch is a single row and still larger than max_request_bytes:
+            else:
+                # If current_batches is empty, raise an error.
+                if len(current_batches) == 0:
+                    raise ValueError(
+                        f"A single PyArrow batch of one row is larger than the maximum request size "
+                        f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
+                    )
+                # Otherwise, build a request, then reset current_size and current_batches.
+                else:
+                    request = _create_request(current_batches)
+                    requests.append(request)
+
+                    current_batches = []
+                    current_size = 0
+                    batches_as_stack.append(batch)
+
+        # Otherwise, add the batch to current_batches.
+        else:
+            current_batches.append(batch)
+            current_size += batch_size
+
+    # Flush remaining batches.
+    if current_batches:
+        request = _create_request(current_batches)
+        requests.append(request)
+
+    return requests
 
 
 def verify_result(client, table, futures):
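To see the new batching strategy in isolation, here is a minimal, BigQuery-free sketch of the same stack-based splitting shown in the hunk above. The function name split_batches, the 64 KiB limit, and the toy table are assumptions made purely for illustration; like the sample, it measures batches with nbytes and halves any batch that would overflow the current chunk.

# Hypothetical, standalone sketch of the stack-based splitting; not part of the sample.
import pyarrow as pa


def split_batches(table: pa.Table, max_chunk_bytes: int):
    """Greedily pack record batches into chunks whose in-memory size (nbytes)
    stays under max_chunk_bytes, halving oversized batches as needed."""
    stack = list(table.to_batches())
    stack.reverse()  # Pop from the end so batches come out in table order.

    chunks, current, current_size = [], [], 0
    while stack:
        batch = stack.pop()
        if current_size + batch.nbytes > max_chunk_bytes:
            if batch.num_rows > 1:
                # Split the batch in half and retry with the smaller pieces.
                mid = batch.num_rows // 2
                stack.append(batch.slice(offset=mid))
                stack.append(batch.slice(offset=0, length=mid))
                continue
            if not current:
                raise ValueError("a single row exceeds the chunk limit")
            # Flush the current chunk and retry this batch against an empty one.
            chunks.append(current)
            current, current_size = [], 0
            stack.append(batch)
        else:
            current.append(batch)
            current_size += batch.nbytes
    if current:
        chunks.append(current)
    return chunks


if __name__ == "__main__":
    table = pa.table({"x": list(range(200_000))})
    chunks = split_batches(table, max_chunk_bytes=64 * 1024)
    assert sum(b.num_rows for chunk in chunks for b in chunk) == table.num_rows
    largest = max(sum(b.nbytes for b in chunk) for chunk in chunks)
    print(f"{len(chunks)} chunks, largest chunk = {largest} bytes")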
@@ -239,7 +247,7 @@ def verify_result(client, table, futures):
     assert len(query_result) == TABLE_LENGTH
 
     # Verify that table was split into multiple requests.
-    assert len(futures) == 21
+    assert len(futures) == 3
 
 
 def main(project_id, dataset):
@@ -264,7 +272,8 @@ def main(project_id, dataset):
     for request in requests:
         future = stream.send(request)
         futures.append(future)
-        future.result()  # Optional, will block until writing is complete.
-
+        # future.result() # Optional, will block until writing is complete.
+    for future in futures:
+        future.result()
     # Verify results.
     verify_result(bq_client, bq_table, futures)
