Commit 6c47797

feat: use content defined chunking in arrow_writer.ParquetWriter

1 parent: c05150e

File tree: 3 files changed (+55, -10)
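For context, content-defined chunking (CDC) lets the Parquet writer place data-page boundaries based on the content itself rather than fixed row counts, so small edits shift fewer pages and deduplicate better in content-addressed storage. A minimal sketch of the underlying pyarrow option this commit wires through (the use_content_defined_chunking keyword and the dict keys are taken from the diff below; the chunk sizes and output path are illustrative, and this assumes a pyarrow version that supports the option):

    import pyarrow as pa
    import pyarrow.parquet as pq

    table = pa.table({"col_1": ["foo", "bar"], "col_2": [1, 2]})

    # CDC derives page boundaries from the bytes themselves, so inserting a
    # row near the start does not shift every page that follows it.
    with pq.ParquetWriter(
        "example.parquet",
        table.schema,
        use_content_defined_chunking={
            "min_chunk_size": 256 * 1024,   # 256 KiB lower bound (illustrative)
            "max_chunk_size": 1024 * 1024,  # 1 MiB upper bound (illustrative)
            "norm_level": 0,                # chunking normalization level (illustrative)
        },
    ) as writer:
        writer.write_table(table)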

src/datasets/arrow_writer.py

Lines changed: 19 additions & 9 deletions

@@ -341,8 +341,6 @@ def __init__(
 class ArrowWriter:
     """Shuffles and writes Examples to Arrow files."""
 
-    _WRITER_CLASS = pa.RecordBatchStreamWriter
-
     def __init__(
         self,
         schema: Optional[pa.Schema] = None,
@@ -430,7 +428,7 @@ def close(self):
         if self._closable_stream and not self.stream.closed:
             self.stream.close()  # This also closes self.pa_writer if it is opened
 
-    def _build_writer(self, inferred_schema: pa.Schema):
+    def _build_schema(self, inferred_schema: pa.Schema):
         schema = self.schema
         inferred_features = Features.from_arrow_schema(inferred_schema)
         if self._features is not None:
@@ -441,19 +439,24 @@ def _build_writer(self, inferred_schema: pa.Schema):
                     if name in fields:
                         if inferred_field == fields[name]:
                             inferred_features[name] = self._features[name]
-                features = inferred_features
+                features = inferred_features
                 schema: pa.Schema = inferred_schema
             else:
-                features = inferred_features
+                features = inferred_features
                 schema: pa.Schema = inferred_features.arrow_schema
+
         if self.disable_nullable:
             schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in schema)
         if self.with_metadata:
-            schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint))
+            schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=features), self.fingerprint))
         else:
             schema = schema.with_metadata({})
-        self._schema = schema
-        self.pa_writer = self._WRITER_CLASS(self.stream, schema)
+
+        return schema, features
+
+    def _build_writer(self, inferred_schema: pa.Schema):
+        self._schema, self._features = self._build_schema(inferred_schema)
+        self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
 
     @property
     def schema(self):
@@ -674,4 +677,11 @@ def finalize(self, close_stream=True):
 
 
 class ParquetWriter(ArrowWriter):
-    _WRITER_CLASS = pq.ParquetWriter
+    def __init__(self, *args, cdc_options=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cdc_options = config.DEFAULT_CDC_OPTIONS if cdc_options is None else cdc_options
+
+    def _build_writer(self, inferred_schema: pa.Schema):
+        self._schema, self._features = self._build_schema(inferred_schema)
+        self.pa_writer = pq.ParquetWriter(self.stream, self._schema, use_content_defined_chunking=self.cdc_options)
+        self.pa_writer.add_key_value_metadata({"content_defined_chunking": json.dumps(self.cdc_options)})
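With this change, ParquetWriter picks up config.DEFAULT_CDC_OPTIONS unless the caller overrides it. A hedged usage sketch mirroring the new test further down (the stream= and cdc_options= keywords come from this diff; the option values are illustrative):

    import pyarrow as pa

    from datasets.arrow_writer import ParquetWriter

    output = pa.BufferOutputStream()
    # Passing cdc_options=None instead would fall back to config.DEFAULT_CDC_OPTIONS.
    with ParquetWriter(
        stream=output,
        cdc_options={"min_chunk_size": 128 * 1024, "max_chunk_size": 512 * 1024, "norm_level": 1},
    ) as writer:
        writer.write({"col_1": "foo", "col_2": 1})
        writer.write({"col_1": "bar", "col_2": 2})
        writer.finalize()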

tests/io/test_parquet.py

Lines changed: 1 addition & 1 deletion

@@ -213,7 +213,7 @@ def test_parquet_write_uses_content_defined_chunking(dataset, tmp_path):
     writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet")
     writer.write()
     assert MockWriter.call_count == 1
-    args, kwargs = MockWriter.call_args
+    _, kwargs = MockWriter.call_args
     # Save or check the arguments as needed
     assert "use_content_defined_chunking" in kwargs
     assert kwargs["use_content_defined_chunking"] == config.DEFAULT_CDC_OPTIONS
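The one-line change above only discards the unused positional args. For reference, a self-contained sketch of the patch(..., wraps=...) pattern both test files rely on (the file name and table are illustrative; this assumes a pyarrow version that accepts use_content_defined_chunking):

    from unittest.mock import patch

    import pyarrow as pa
    import pyarrow.parquet as pq

    table = pa.table({"x": [1, 2, 3]})

    # wraps= records every call while still delegating to the real ParquetWriter,
    # so the file is actually written and the kwargs can be inspected afterwards.
    with patch("pyarrow.parquet.ParquetWriter", wraps=pq.ParquetWriter) as MockWriter:
        with pq.ParquetWriter("t.parquet", table.schema, use_content_defined_chunking=True) as writer:
            writer.write_table(table)

    _, kwargs = MockWriter.call_args  # positional args (path, schema) are discarded
    assert kwargs["use_content_defined_chunking"] is True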

tests/test_arrow_writer.py

Lines changed: 35 additions & 0 deletions

@@ -1,4 +1,5 @@
 import copy
+import json
 import os
 import tempfile
 from unittest import TestCase
@@ -9,6 +10,7 @@
 import pyarrow.parquet as pq
 import pytest
 
+from datasets import config
 from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence
 from datasets.features import Array2D, ClassLabel, Features, Image, Value
 from datasets.features.features import Array2DExtensionType, cast_to_python_objects
@@ -334,6 +336,39 @@ def test_parquet_writer_write():
     assert pa_table.to_pydict() == {"col_1": ["foo", "bar"], "col_2": [1, 2]}
 
 
+custom_cdc_options = {
+    "min_chunk_size": 128 * 1024,  # 128 KiB
+    "max_chunk_size": 512 * 1024,  # 512 KiB
+    "norm_level": 1,
+}
+
+
+@pytest.mark.parametrize(
+    ("cdc_options", "expected_options"), [(None, config.DEFAULT_CDC_OPTIONS), (custom_cdc_options, custom_cdc_options)]
+)
+def test_parquet_write_uses_content_defined_chunking(cdc_options, expected_options):
+    output = pa.BufferOutputStream()
+    with patch("pyarrow.parquet.ParquetWriter", wraps=pq.ParquetWriter) as MockWriter:
+        with ParquetWriter(stream=output, cdc_options=cdc_options) as writer:
+            writer.write({"col_1": "foo", "col_2": 1})
+            writer.write({"col_1": "bar", "col_2": 2})
+            writer.finalize()
+        assert MockWriter.call_count == 1
+        _, kwargs = MockWriter.call_args
+        assert "use_content_defined_chunking" in kwargs
+        assert kwargs["use_content_defined_chunking"] == expected_options
+
+    # read metadata from the output stream
+    with pa.input_stream(output.getvalue()) as stream:
+        metadata = pq.read_metadata(stream)
+        key_value_metadata = metadata.metadata
+
+    # check that the content defined chunking options are persisted
+    assert b"content_defined_chunking" in key_value_metadata
+    json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
+    assert json.loads(json_encoded_options) == expected_options
+
+
 @require_pil
 @pytest.mark.parametrize("embed_local_files", [False, True])
 def test_writer_embed_local_files(tmp_path, embed_local_files):
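Because _build_writer persists the options as file-level key/value metadata, they can be recovered from any file written this way, not only from an in-memory buffer as in the test. A small sketch (the path is illustrative; the b"content_defined_chunking" key is the one written in the first diff above):

    import json

    import pyarrow.parquet as pq

    # Key/value metadata lives in the Parquet footer and travels with the file.
    key_value_metadata = pq.read_metadata("example.parquet").metadata
    options = json.loads(key_value_metadata[b"content_defined_chunking"].decode("utf-8"))
    print(options)  # e.g. {"min_chunk_size": 131072, "max_chunk_size": 524288, "norm_level": 1}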
