Commit 889550b

Merge pull request #421 from mabel-dev/copilot/fix-21057b26-b99a-47df-8604-20910c37976a
Add support for list of strings in sort_by and dictionary encoding in parquet writer
2 parents 27b4d31 + 989983b commit 889550b
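
Taken together, the change lets callers pass sort_by as a plain list of column names and forward use_dictionary to the parquet writer. Below is a minimal usage sketch assembled from the tests added in this commit; the import paths are assumptions about the mabel package layout, and the dataset name is illustrative:

import datetime

from mabel.adapters.disk import DiskWriter  # assumed import path
from mabel.data import BatchWriter  # assumed import path

w = BatchWriter(
    inner_writer=DiskWriter,
    dataset="_temp_example",  # illustrative dataset name
    format="parquet",
    date=datetime.datetime.utcnow().date(),
    schema=[
        {"name": "id", "type": "INTEGER"},
        {"name": "category", "type": "VARCHAR"},
    ],
    sort_by=["category", "id"],   # new: a plain list of column names
    use_dictionary=["category"],  # new: dictionary-encode only this column
)

for i in range(100):
    w.append({"id": i, "category": f"category_{i % 5}"})

w.finalize()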

3 files changed: +203 -7 lines

mabel/data/writers/internals/blob_writer.py

Lines changed: 14 additions & 6 deletions

@@ -1,7 +1,7 @@
 import io
 import json
 import threading
-from typing import Optional
+from typing import List, Optional, Union

 import orjson
 import orso
@@ -32,13 +32,15 @@ def __init__(
         format: str = "parquet",
         schema: Optional[RelationSchema] = None,
         parquet_row_group_size: int = 5000,
-        sort_by: Optional[str] = None,
+        sort_by: Optional[Union[str, List]] = None,
+        use_dictionary: Optional[Union[bool, List[str]]] = None,
         **kwargs,
     ):
         self.format = format
         self.maximum_blob_size = blob_size
         self.parquet_row_group_size = parquet_row_group_size
         self.sort_by = sort_by
+        self.use_dictionary = use_dictionary

         if format not in SUPPORTED_FORMATS_ALGORITHMS:
             raise ValueError(
@@ -166,12 +168,18 @@ def commit(self):

         # sort the table if sort_by is specified
         if self.sort_by:
-            pytable = pytable.sort_by(self.sort_by)
+            # Convert list of strings to PyArrow format
+            sort_spec = self.sort_by
+            if isinstance(self.sort_by, list) and all(isinstance(item, str) for item in self.sort_by):
+                # Convert list of strings to list of tuples with default ascending order
+                sort_spec = [(col, "ascending") for col in self.sort_by]
+            pytable = pytable.sort_by(sort_spec)

         tempfile = io.BytesIO()
-        pyarrow.parquet.write_table(
-            pytable, where=tempfile, row_group_size=self.parquet_row_group_size
-        )
+        write_kwargs = {"row_group_size": self.parquet_row_group_size}
+        if self.use_dictionary is not None:
+            write_kwargs["use_dictionary"] = self.use_dictionary
+        pyarrow.parquet.write_table(pytable, where=tempfile, **write_kwargs)

         tempfile.seek(0)
         write_buffer = tempfile.read()
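
This change leans on two PyArrow behaviours worth seeing in isolation: Table.sort_by accepts a single column name or a list of (column, order) tuples, but not a bare list of names (hence the tuple conversion above), and write_table's use_dictionary accepts either a bool or a list of column names. A minimal standalone sketch, with an illustrative table and values:

import io

import pyarrow
import pyarrow.parquet

table = pyarrow.table({"category": ["B", "A", "B"], "id": [3, 1, 4]})

# sort_by wants (column, order) tuples for multi-column sorts, which is
# what the writer now builds from a plain list of names
sort_spec = [(col, "ascending") for col in ["category", "id"]]
sorted_table = table.sort_by(sort_spec)

buffer = io.BytesIO()
pyarrow.parquet.write_table(
    sorted_table,
    where=buffer,
    row_group_size=5000,
    use_dictionary=["category"],  # dictionary-encode only this column
)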

mabel/version.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Store the version here so:
 # 1) we don't load dependencies by storing it in __init__.py
 # 2) we can import it in setup.py for the same reason
-__version__ = "0.6.28"
+__version__ = "0.6.29"

 # nodoc - don't add to the documentation wiki

tests/test_writer_parquet_features.py

Lines changed: 188 additions & 0 deletions

@@ -148,6 +148,194 @@ def test_parquet_default_row_group_size():
     shutil.rmtree("_temp_default", ignore_errors=True)


+def test_parquet_sorting_list_of_strings():
+    """Test that sort_by parameter can accept a list of strings"""
+    shutil.rmtree("_temp_sort_list", ignore_errors=True)
+
+    # Create a writer with sorting by list of strings
+    w = BatchWriter(
+        inner_writer=DiskWriter,
+        dataset="_temp_sort_list",
+        format="parquet",
+        date=datetime.datetime.utcnow().date(),
+        schema=[
+            {"name": "id", "type": "INTEGER"},
+            {"name": "category", "type": "VARCHAR"},
+            {"name": "value", "type": "VARCHAR"}
+        ],
+        sort_by=["category", "id"],  # Sort by category first, then id
+        parquet_row_group_size=5000,
+    )
+
+    # Write records in random order
+    records_to_write = [
+        {"id": 3, "category": "B", "value": "value_3"},
+        {"id": 1, "category": "A", "value": "value_1"},
+        {"id": 4, "category": "B", "value": "value_4"},
+        {"id": 2, "category": "A", "value": "value_2"},
+        {"id": 5, "category": "C", "value": "value_5"},
+    ]
+
+    for record in records_to_write:
+        w.append(record)
+
+    w.finalize()
+
+    # Read back and verify the data is sorted by category, then id
+    r = Reader(inner_reader=DiskReader, dataset="_temp_sort_list")
+    records = list(r)
+
+    assert len(records) == 5, f"Expected 5 records, got {len(records)}"
+
+    # Check that records are sorted by category first, then by id
+    expected_order = [
+        {"id": 1, "category": "A", "value": "value_1"},
+        {"id": 2, "category": "A", "value": "value_2"},
+        {"id": 3, "category": "B", "value": "value_3"},
+        {"id": 4, "category": "B", "value": "value_4"},
+        {"id": 5, "category": "C", "value": "value_5"},
+    ]
+
+    for i, record in enumerate(records):
+        assert record["id"] == expected_order[i]["id"], f"Record {i} id mismatch: {record['id']} != {expected_order[i]['id']}"
+        assert record["category"] == expected_order[i]["category"], f"Record {i} category mismatch"
+
+    shutil.rmtree("_temp_sort_list", ignore_errors=True)
+
+
+def test_parquet_sorting_single_column_list():
+    """Test that sort_by parameter can accept a list with a single string"""
+    shutil.rmtree("_temp_sort_single_list", ignore_errors=True)
+
+    # Create a writer with sorting by a list containing a single column
+    w = BatchWriter(
+        inner_writer=DiskWriter,
+        dataset="_temp_sort_single_list",
+        format="parquet",
+        date=datetime.datetime.utcnow().date(),
+        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
+        sort_by=["id"],  # Sort by id column as a list
+        parquet_row_group_size=5000,
+    )
+
+    # Write records in reverse order
+    for i in range(10, 0, -1):
+        w.append({"id": i, "value": f"value_{i}"})
+
+    w.finalize()
+
+    # Read back and verify the data is sorted
+    r = Reader(inner_reader=DiskReader, dataset="_temp_sort_single_list")
+    records = list(r)
+
+    assert len(records) == 10, f"Expected 10 records, got {len(records)}"
+
+    # Check that records are sorted by id
+    ids = [record["id"] for record in records]
+    assert ids == list(range(1, 11)), f"Records are not sorted correctly: {ids}"
+
+    shutil.rmtree("_temp_sort_single_list", ignore_errors=True)
+
+
+def test_parquet_dictionary_encoding_all():
+    """Test that use_dictionary parameter can be set to True for all columns"""
+    shutil.rmtree("_temp_dict_all", ignore_errors=True)
+
+    w = BatchWriter(
+        inner_writer=DiskWriter,
+        dataset="_temp_dict_all",
+        format="parquet",
+        date=datetime.datetime.utcnow().date(),
+        schema=[
+            {"name": "id", "type": "INTEGER"},
+            {"name": "category", "type": "VARCHAR"}
+        ],
+        use_dictionary=True,  # Enable dictionary encoding for all columns
+    )
+
+    # Write records with repeated category values (good for dictionary encoding)
+    for i in range(100):
+        w.append({"id": i, "category": f"category_{i % 5}"})
+
+    w.finalize()
+
+    # Read back and verify data
+    r = Reader(inner_reader=DiskReader, dataset="_temp_dict_all")
+    records = list(r)
+    assert len(records) == 100, f"Expected 100 records, got {len(records)}"
+
+    shutil.rmtree("_temp_dict_all", ignore_errors=True)
+
+
+def test_parquet_dictionary_encoding_disabled():
+    """Test that use_dictionary parameter can be set to False to disable dictionary encoding"""
+    shutil.rmtree("_temp_dict_disabled", ignore_errors=True)
+
+    w = BatchWriter(
+        inner_writer=DiskWriter,
+        dataset="_temp_dict_disabled",
+        format="parquet",
+        date=datetime.datetime.utcnow().date(),
+        schema=[
+            {"name": "id", "type": "INTEGER"},
+            {"name": "category", "type": "VARCHAR"}
+        ],
+        use_dictionary=False,  # Disable dictionary encoding
+    )
+
+    # Write records
+    for i in range(100):
+        w.append({"id": i, "category": f"category_{i % 5}"})
+
+    w.finalize()
+
+    # Read back and verify data
+    r = Reader(inner_reader=DiskReader, dataset="_temp_dict_disabled")
+    records = list(r)
+    assert len(records) == 100, f"Expected 100 records, got {len(records)}"
+
+    shutil.rmtree("_temp_dict_disabled", ignore_errors=True)
+
+
+def test_parquet_dictionary_encoding_specific_columns():
+    """Test that use_dictionary parameter can specify specific columns for dictionary encoding"""
+    shutil.rmtree("_temp_dict_specific", ignore_errors=True)
+
+    w = BatchWriter(
+        inner_writer=DiskWriter,
+        dataset="_temp_dict_specific",
+        format="parquet",
+        date=datetime.datetime.utcnow().date(),
+        schema=[
+            {"name": "id", "type": "INTEGER"},
+            {"name": "category", "type": "VARCHAR"},
+            {"name": "value", "type": "VARCHAR"}
+        ],
+        use_dictionary=["category"],  # Only encode 'category' column with dictionary
+    )
+
+    # Write records with repeated category values
+    for i in range(100):
+        w.append({
+            "id": i,
+            "category": f"category_{i % 5}",
+            "value": f"unique_value_{i}"
+        })
+
+    w.finalize()
+
+    # Read back and verify data
+    r = Reader(inner_reader=DiskReader, dataset="_temp_dict_specific")
+    records = list(r)
+    assert len(records) == 100, f"Expected 100 records, got {len(records)}"
+
+    # Verify the data is correct
+    assert records[0]["category"] == "category_0"
+    assert records[50]["category"] == "category_0"
+
+    shutil.rmtree("_temp_dict_specific", ignore_errors=True)
+
+
 if __name__ == "__main__":  # pragma: no cover
     from tests.helpers.runner import run_tests
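
The tests assert only that records survive the round trip; whether dictionary encoding was actually applied is recorded in the parquet file's metadata rather than in the data itself. A hedged sketch of such a check with plain PyArrow; the path below is hypothetical, substitute a parquet blob produced by one of the tests above:

import pyarrow.parquet

# Hypothetical path: point this at a blob the writer produced
parquet_file = pyarrow.parquet.ParquetFile("_temp_dict_specific/part_0000.parquet")
row_group = parquet_file.metadata.row_group(0)
for i in range(row_group.num_columns):
    column = row_group.column(i)
    # dictionary-encoded column chunks report PLAIN_DICTIONARY or
    # RLE_DICTIONARY among their encodings
    print(column.path_in_schema, column.encodings)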