Commit 8867ecd

fix: logic update in generate_write_requests and updated tests
1 parent fa73c4b commit 8867ecd

2 files changed: +31 −28 lines


packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py

Lines changed: 15 additions & 16 deletions
```diff
@@ -185,33 +185,32 @@ def _create_request(batches):
         batch = batches.pop()
         batch_size = batch.nbytes
 
-        if current_size + batch_size > max_request_bytes:
+        # If the batch is larger than max_request_bytes, split it into 2 sub batches.
+        if batch_size > max_request_bytes:
             if batch.num_rows > 1:
                 # Split the batch into 2 sub batches with identical chunksizes
                 mid = batch.num_rows // 2
                 batch_left = batch.slice(offset=0, length=mid)
                 batch_right = batch.slice(offset=mid)
 
                 # Append the new batches into the stack and continue poping.
-                batches.append(batch_right)
                 batches.append(batch_left)
+                batches.append(batch_right)
                 continue
-
             # If the batch is single row and still larger than max_request_bytes
             else:
-                # If current batches is empty, throw error
-                if len(current_batches) == 0:
-                    raise ValueError(
-                        f"A single PyArrow batch of one row is larger than the maximum request size "
-                        f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
-                    )
-                # Otherwise, generate the request, reset current_size and current_batches
-                else:
-                    yield _create_request(current_batches)
-
-                    current_batches = []
-                    current_size = 0
-                    batches.append(batch)
+                raise ValueError(
+                    f"A single PyArrow batch of one row is larger than the maximum request size "
+                    f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
+                )
+        # The current batches are ok to form a request when next batch will exceed the max_request_bytes.
+        if current_size + batch_size > max_request_bytes:
+            # Current batches can't be empty otherwise it will suffice batch_size > max_request_bytes above.
+            yield _create_request(current_batches)
+
+            current_batches = []
+            current_size = 0
+            batches.append(batch)
 
         # Otherwise, add the batch into current_batches
         else:
```
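To make the new control flow easier to follow, here is a minimal, self-contained sketch of what `generate_write_requests` does after this change. The stack-based popping, splitting, and flush logic mirror the diff above, and the 7 MB limit matches what the updated test asserts, but the surrounding `while` loop, the final flush, and the simplified `_create_request` stand-in (it merges the batches into a `pyarrow.Table` so the sketch runs on its own) are assumptions; the real sample builds BigQuery Storage `AppendRowsRequest` objects from the serialized batches.

```python
import pyarrow as pa

MAX_REQUEST_BYTES = 7 * 1024 * 1024  # 7 MB, the limit checked by the updated test


def _create_request(batches):
    # Simplified stand-in: the real sample serializes these batches into an
    # AppendRowsRequest; merging them into a Table keeps this sketch runnable.
    return pa.Table.from_batches(batches)


def generate_write_requests(table, max_request_bytes=MAX_REQUEST_BYTES):
    # Treat the table's record batches as a stack and pop until it is drained.
    batches = list(table.to_batches())
    current_batches = []
    current_size = 0

    while batches:
        batch = batches.pop()
        batch_size = batch.nbytes

        # An oversized batch is split in half and both halves are pushed back.
        if batch_size > max_request_bytes:
            if batch.num_rows > 1:
                mid = batch.num_rows // 2
                batches.append(batch.slice(offset=0, length=mid))
                batches.append(batch.slice(offset=mid))
                continue
            # A single row that still exceeds the limit cannot be split further.
            raise ValueError(
                f"A single PyArrow batch of one row is larger than the maximum "
                f"request size ({batch_size} > {max_request_bytes})."
            )

        # Flush the accumulated batches once adding this one would exceed the
        # limit; the batch itself goes back on the stack for the next request.
        if current_size + batch_size > max_request_bytes:
            yield _create_request(current_batches)
            current_batches = []
            current_size = 0
            batches.append(batch)
        else:
            current_batches.append(batch)
            current_size += batch_size

    # Emit whatever is left once the stack is drained (assumed final flush).
    if current_batches:
        yield _create_request(current_batches)
```

The behavioral change relative to the previous version is visible in the diff: the oversize check now runs before the aggregation check, and a single row that exceeds the limit always raises instead of only raising when nothing had been accumulated yet.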

packages/google-cloud-bigquery-storage/samples/pyarrow/test_generate_write_requests.py

Lines changed: 16 additions & 12 deletions
```diff
@@ -36,18 +36,20 @@ def create_table_with_batches(num_batches, rows_per_batch):
 # Test generate_write_requests with different numbers of batches in the input table.
 # The total rows in the generated table is constantly 1000000.
 @pytest.mark.parametrize(
-    "num_batches, rows_per_batch",
+    "num_batches, rows_per_batch, expected_requests",
     [
-        (1, 1000000),
-        (10, 100000),
-        (100, 10000),
-        (1000, 1000),
-        (10000, 100),
-        (100000, 10),
-        (1000000, 1),
+        (1, 1000000, 32),
+        (10, 100000, 40),
+        (100, 10000, 34),
+        (1000, 1000, 26),
+        (10000, 100, 26),
+        (100000, 10, 26),
+        (1000000, 1, 26),
     ],
 )
-def test_generate_write_requests_varying_batches(num_batches, rows_per_batch):
+def test_generate_write_requests_varying_batches(
+    num_batches, rows_per_batch, expected_requests
+):
     """Test generate_write_requests with different numbers of batches in the input table."""
     # Create a table that returns `num_batches` when to_batches() is called.
     table = create_table_with_batches(num_batches, rows_per_batch)
@@ -63,15 +65,17 @@ def test_generate_write_requests_varying_batches(num_batches, rows_per_batch):
         f"\nTime used to generate requests for {num_batches} batches: {end_time - start_time:.4f} seconds"
     )
 
-    # We expect the requests to be aggregated until 7MB.
-    # Since the row number is constant, the number of requests should be deterministic.
-    assert len(requests) == 26
+    assert len(requests) == expected_requests
 
     # Verify total rows in requests matches total rows in table
     total_rows_processed = 0
     for request in requests:
         # Deserialize the batch from the request to count rows
         serialized_batch = request.arrow_rows.rows.serialized_record_batch
+
+        # Verify the batch size is less than 7MB
+        assert len(serialized_batch) <= 7 * 1024 * 1024
+
         # We need a schema to read the batch. The schema is PYARROW_SCHEMA.
         batch = pa.ipc.read_record_batch(
             serialized_batch, append_rows_with_arrow.PYARROW_SCHEMA
```
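As a quick, informal check of the sketch above (not part of the sample's test suite), the same row-count invariant the updated test verifies can be exercised against a small in-memory table; the column name and the 64 KB limit here are arbitrary values chosen for illustration.

```python
import pyarrow as pa

# Relies on the generate_write_requests sketch shown earlier on this page.
data = pa.table({"value": list(range(100_000))})
requests = list(generate_write_requests(data, max_request_bytes=64 * 1024))

# Mirrors the updated test: no rows are lost or duplicated across requests.
assert sum(request.num_rows for request in requests) == data.num_rows
print(f"Generated {len(requests)} requests for {data.num_rows} rows")
```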
