Skip to content

Commit 26fbade

Browse files
Copilot and joocer committed
Add use_dictionary parameter for parquet dictionary encoding
Co-authored-by: joocer <1688479+joocer@users.noreply.github.com>
1 parent ca54e77 commit 26fbade

File tree

2 files changed

+105
-3
lines changed

2 files changed

+105
-3
lines changed

mabel/data/writers/internals/blob_writer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ def __init__(
3333
schema: Optional[RelationSchema] = None,
3434
parquet_row_group_size: int = 5000,
3535
sort_by: Optional[Union[str, List]] = None,
36+
use_dictionary: Optional[Union[bool, List[str]]] = None,
3637
**kwargs,
3738
):
3839
self.format = format
3940
self.maximum_blob_size = blob_size
4041
self.parquet_row_group_size = parquet_row_group_size
4142
self.sort_by = sort_by
43+
self.use_dictionary = use_dictionary
4244

4345
if format not in SUPPORTED_FORMATS_ALGORITHMS:
4446
raise ValueError(
@@ -172,9 +174,10 @@ def commit(self):
172174
pytable = pytable.sort_by(sort_spec)
173175

174176
tempfile = io.BytesIO()
175-
pyarrow.parquet.write_table(
176-
pytable, where=tempfile, row_group_size=self.parquet_row_group_size
177-
)
177+
write_kwargs = {"row_group_size": self.parquet_row_group_size}
178+
if self.use_dictionary is not None:
179+
write_kwargs["use_dictionary"] = self.use_dictionary
180+
pyarrow.parquet.write_table(pytable, where=tempfile, **write_kwargs)
178181

179182
tempfile.seek(0)
180183
write_buffer = tempfile.read()

tests/test_writer_parquet_features.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,105 @@ def test_parquet_sorting_single_column_list():
237237
shutil.rmtree("_temp_sort_single_list", ignore_errors=True)
238238

239239

240+
def test_parquet_dictionary_encoding_all():
    """Test that use_dictionary=True enables dictionary encoding for all columns."""
    shutil.rmtree("_temp_dict_all", ignore_errors=True)
    try:
        w = BatchWriter(
            inner_writer=DiskWriter,
            dataset="_temp_dict_all",
            format="parquet",
            # utcnow() is deprecated since Python 3.12; use an aware UTC datetime.
            date=datetime.datetime.now(datetime.timezone.utc).date(),
            schema=[
                {"name": "id", "type": "INTEGER"},
                {"name": "category", "type": "VARCHAR"},
            ],
            use_dictionary=True,  # enable dictionary encoding for all columns
        )

        # Repeated category values are a good fit for dictionary encoding.
        for i in range(100):
            w.append({"id": i, "category": f"category_{i % 5}"})

        w.finalize()

        # Read back and verify the records round-trip intact.
        r = Reader(inner_reader=DiskReader, dataset="_temp_dict_all")
        records = list(r)
        assert len(records) == 100, f"Expected 100 records, got {len(records)}"
    finally:
        # Clean up even when an assertion fails, so reruns start fresh.
        shutil.rmtree("_temp_dict_all", ignore_errors=True)
268+
269+
270+
def test_parquet_dictionary_encoding_disabled():
    """Test that use_dictionary=False disables dictionary encoding."""
    shutil.rmtree("_temp_dict_disabled", ignore_errors=True)
    try:
        w = BatchWriter(
            inner_writer=DiskWriter,
            dataset="_temp_dict_disabled",
            format="parquet",
            # utcnow() is deprecated since Python 3.12; use an aware UTC datetime.
            date=datetime.datetime.now(datetime.timezone.utc).date(),
            schema=[
                {"name": "id", "type": "INTEGER"},
                {"name": "category", "type": "VARCHAR"},
            ],
            use_dictionary=False,  # disable dictionary encoding
        )

        # Write records (same data shape as the enabled-encoding test).
        for i in range(100):
            w.append({"id": i, "category": f"category_{i % 5}"})

        w.finalize()

        # Read back and verify the records round-trip intact.
        r = Reader(inner_reader=DiskReader, dataset="_temp_dict_disabled")
        records = list(r)
        assert len(records) == 100, f"Expected 100 records, got {len(records)}"
    finally:
        # Clean up even when an assertion fails, so reruns start fresh.
        shutil.rmtree("_temp_dict_disabled", ignore_errors=True)
298+
299+
300+
def test_parquet_dictionary_encoding_specific_columns():
    """Test that use_dictionary can name specific columns for dictionary encoding."""
    shutil.rmtree("_temp_dict_specific", ignore_errors=True)
    try:
        w = BatchWriter(
            inner_writer=DiskWriter,
            dataset="_temp_dict_specific",
            format="parquet",
            # utcnow() is deprecated since Python 3.12; use an aware UTC datetime.
            date=datetime.datetime.now(datetime.timezone.utc).date(),
            schema=[
                {"name": "id", "type": "INTEGER"},
                {"name": "category", "type": "VARCHAR"},
                {"name": "value", "type": "VARCHAR"},
            ],
            use_dictionary=["category"],  # only encode 'category' with a dictionary
        )

        # 'category' repeats (dictionary-friendly); 'value' is unique per row.
        for i in range(100):
            w.append({
                "id": i,
                "category": f"category_{i % 5}",
                "value": f"unique_value_{i}",
            })

        w.finalize()

        # Read back and verify the records round-trip intact.
        r = Reader(inner_reader=DiskReader, dataset="_temp_dict_specific")
        records = list(r)
        assert len(records) == 100, f"Expected 100 records, got {len(records)}"

        # Spot-check the cyclic category values (i % 5 pattern).
        assert records[0]["category"] == "category_0"
        assert records[50]["category"] == "category_0"
    finally:
        # Clean up even when an assertion fails, so reruns start fresh.
        shutil.rmtree("_temp_dict_specific", ignore_errors=True)
337+
338+
240339
if __name__ == "__main__": # pragma: no cover
241340
from tests.helpers.runner import run_tests
242341

0 commit comments

Comments
 (0)