Skip to content

Commit 2464fd3

Browse files
committed
enhance unit test for write request generation
1 parent ec49ec5 commit 2464fd3

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
17+
import pyarrow as pa
18+
import pytest
19+
20+
from . import append_rows_with_arrow
21+
22+
23+
def create_table_with_batches(num_batches, rows_per_batch):
    """Build a pyarrow Table made of `num_batches` identical record batches.

    A seed table of `rows_per_batch` rows is generated once; its single
    record batch is then repeated `num_batches` times to form the result.
    """
    # Generate a small seed table to obtain one valid record batch.
    seed_table = append_rows_with_arrow.generate_pyarrow_table(rows_per_batch)
    seed_batches = seed_table.to_batches()
    # The seed table must serialize to exactly one batch for replication to work.
    assert len(seed_batches) == 1

    # Repeat that single batch to reach the requested batch count.
    replicated = [seed_batches[0] for _ in range(num_batches)]
    return pa.Table.from_batches(replicated)
34+
35+
36+
# Test generate_write_requests with different numbers of batches in the input table.
# The total rows in the generated table is constantly 1000000.
@pytest.mark.parametrize(
    "num_batches, rows_per_batch",
    [
        (1, 1000000),
        (10, 100000),
        (100, 10000),
        (1000, 1000),
        (10000, 100),
        (100000, 10),
        (1000000, 1),
    ],
)
def test_generate_write_requests_varying_batches(num_batches, rows_per_batch):
    """Test generate_write_requests with different numbers of batches in the input table."""
    # Build a table whose to_batches() yields exactly `num_batches` batches.
    table = create_table_with_batches(num_batches, rows_per_batch)
    assert len(table.to_batches()) == num_batches

    # Time request generation; the print is informational (visible with -s).
    t0 = time.perf_counter()
    requests = list(append_rows_with_arrow.generate_write_requests(table))
    t1 = time.perf_counter()
    print(
        f"\nTime used to generate requests for {num_batches} batches: {t1 - t0:.4f} seconds"
    )

    # Requests are aggregated up to 7MB; with a fixed total row count the
    # number of emitted requests is therefore deterministic.
    assert len(requests) == 26

    # Round-trip every serialized batch (decoded against PYARROW_SCHEMA) and
    # confirm that no rows were dropped during request generation.
    total_rows_processed = sum(
        pa.ipc.read_record_batch(
            request.arrow_rows.rows.serialized_record_batch,
            append_rows_with_arrow.PYARROW_SCHEMA,
        ).num_rows
        for request in requests
    )

    assert total_rows_processed == num_batches * rows_per_batch

0 commit comments

Comments
 (0)