Commit c3b30b3

Copilot and joocer committed
Add parquet row group size and sorting features
Co-authored-by: joocer <1688479+joocer@users.noreply.github.com>
1 parent c42000d · commit c3b30b3

2 files changed: 165 additions, 1 deletion

mabel/data/writers/internals/blob_writer.py

Lines changed: 11 additions & 1 deletion
@@ -31,10 +31,14 @@ def __init__(
         blob_size: int = BLOB_SIZE,
         format: str = "parquet",
         schema: Optional[RelationSchema] = None,
+        parquet_row_group_size: int = 5000,
+        sort_by: Optional[str] = None,
         **kwargs,
     ):
         self.format = format
         self.maximum_blob_size = blob_size
+        self.parquet_row_group_size = parquet_row_group_size
+        self.sort_by = sort_by

         if format not in SUPPORTED_FORMATS_ALGORITHMS:
             raise ValueError(
@@ -158,8 +162,14 @@ def commit(self):
         if self.schema:
             pytable = self._normalize_arrow_schema(pytable, self.schema)

+        # sort the table if sort_by is specified
+        if self.sort_by:
+            pytable = pytable.sort_by(self.sort_by)
+
         tempfile = io.BytesIO()
-        pyarrow.parquet.write_table(pytable, where=tempfile)
+        pyarrow.parquet.write_table(
+            pytable, where=tempfile, row_group_size=self.parquet_row_group_size
+        )

         tempfile.seek(0)
         write_buffer = tempfile.read()
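
The change leans on two pyarrow behaviours: Table.sort_by, which accepts either a column name or a list of (column, order) tuples, and the row_group_size argument to pyarrow.parquet.write_table, which caps how many rows land in each row group. A minimal standalone sketch of those calls follows; it is illustrative only, not mabel code, and the table and buffer are invented for the example.

# Standalone sketch (not mabel code): the pyarrow calls used in the diff above.
import io

import pyarrow
import pyarrow.parquet

# Build a small table with ids in reverse order
table = pyarrow.table(
    {"id": list(range(10, 0, -1)), "value": [f"v{i}" for i in range(10, 0, -1)]}
)

# sort_by accepts a column name (ascending) or a list of (column, order) tuples
table = table.sort_by([("id", "ascending")])

# row_group_size caps how many rows go into each parquet row group
buffer = io.BytesIO()
pyarrow.parquet.write_table(table, where=buffer, row_group_size=4)

buffer.seek(0)
parquet_file = pyarrow.parquet.ParquetFile(buffer)
print(parquet_file.num_row_groups)  # 10 rows / 4 per group -> 3 row groups

Smaller row groups let readers use row-group statistics to skip irrelevant data, at the cost of more metadata per file; the 5000-row default in this commit keeps that trade-off moderate.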
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
import shutil
import os
import sys
import datetime

sys.path.insert(1, os.path.join(sys.path[0], ".."))

from mabel.adapters.disk import DiskReader, DiskWriter
from mabel.data import BatchWriter, Reader


def test_parquet_row_group_size():
    """Test that parquet_row_group_size parameter is passed and used correctly"""
    shutil.rmtree("_temp_rowgroup", ignore_errors=True)

    # Create a writer with custom row group size
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_rowgroup",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        parquet_row_group_size=100,  # Small row group for testing
    )

    # Write 1000 records
    for i in range(1000):
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify we can read the data
    r = Reader(inner_reader=DiskReader, dataset="_temp_rowgroup")
    records = list(r)
    assert len(records) == 1000, f"Expected 1000 records, got {len(records)}"

    # Verify the parquet file has multiple row groups
    import glob
    import pyarrow.parquet as pq

    parquet_files = glob.glob("_temp_rowgroup/**/*.parquet", recursive=True)
    assert len(parquet_files) > 0, "No parquet files found"

    # Check row groups in the first file
    parquet_file = pq.ParquetFile(parquet_files[0])
    num_row_groups = parquet_file.num_row_groups

    # With 1000 records and row_group_size=100, we should have multiple row groups
    # (exact number depends on how records are distributed across blobs)
    assert num_row_groups > 0, f"Expected at least 1 row group, got {num_row_groups}"

    shutil.rmtree("_temp_rowgroup", ignore_errors=True)


def test_parquet_sorting():
    """Test that sort_by parameter sorts records correctly"""
    shutil.rmtree("_temp_sort", ignore_errors=True)

    # Create a writer with sorting
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_sort",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        sort_by="id",  # Sort by id column
        parquet_row_group_size=5000,
    )

    # Write records in reverse order
    for i in range(100, 0, -1):
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify the data is sorted
    r = Reader(inner_reader=DiskReader, dataset="_temp_sort")
    records = list(r)

    assert len(records) == 100, f"Expected 100 records, got {len(records)}"

    # Check that records are sorted by id
    ids = [record["id"] for record in records]
    assert ids == list(range(1, 101)), f"Records are not sorted correctly: {ids[:10]}..."

    shutil.rmtree("_temp_sort", ignore_errors=True)


def test_parquet_sorting_descending():
    """Test that sort_by parameter can sort in descending order"""
    shutil.rmtree("_temp_sort_desc", ignore_errors=True)

    # Create a writer with descending sorting
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_sort_desc",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        sort_by=[("id", "descending")],  # Sort by id in descending order
        parquet_row_group_size=5000,
    )

    # Write records in random order
    for i in [50, 10, 90, 30, 70, 20, 80, 40, 60, 100]:
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify the data is sorted in descending order
    r = Reader(inner_reader=DiskReader, dataset="_temp_sort_desc")
    records = list(r)

    assert len(records) == 10, f"Expected 10 records, got {len(records)}"

    # Check that records are sorted by id in descending order
    ids = [record["id"] for record in records]
    expected_ids = [100, 90, 80, 70, 60, 50, 40, 30, 20, 10]
    assert ids == expected_ids, f"Records are not sorted correctly in descending order: {ids}"

    shutil.rmtree("_temp_sort_desc", ignore_errors=True)


def test_parquet_default_row_group_size():
    """Test that default row group size is 5000"""
    shutil.rmtree("_temp_default", ignore_errors=True)

    # Create a writer without specifying row group size
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_default",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
    )

    # Write some records
    for i in range(100):
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify we can read the data
    r = Reader(inner_reader=DiskReader, dataset="_temp_default")
    records = list(r)
    assert len(records) == 100, f"Expected 100 records, got {len(records)}"

    shutil.rmtree("_temp_default", ignore_errors=True)


if __name__ == "__main__":  # pragma: no cover
    from tests.helpers.runner import run_tests

    run_tests()
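
The row-group test above only asserts num_row_groups > 0; the actual per-group row counts are also available from the parquet file metadata. A small sketch of that check (assuming a dataset has already been written to _temp_rowgroup with parquet_row_group_size=100, as in the first test; the path is only an example):

# Sketch: inspect how many rows ended up in each row group of each written blob,
# complementing the num_row_groups assertion in test_parquet_row_group_size.
import glob

import pyarrow.parquet as pq

for path in glob.glob("_temp_rowgroup/**/*.parquet", recursive=True):
    parquet_file = pq.ParquetFile(path)
    row_counts = [
        parquet_file.metadata.row_group(i).num_rows
        for i in range(parquet_file.num_row_groups)
    ]
    print(path, row_counts)  # e.g. [100, 100, ...] with a smaller final group per blob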
