
Commit 00a8c54

test: cover more input parameter values for ParquetDatasetWriter
1 parent b306a37

File tree

2 files changed: +40 additions, -35 deletions

src/datasets/io/parquet.py

Lines changed: 10 additions & 14 deletions
@@ -78,44 +78,37 @@ def __init__(
        path_or_buf: Union[PathLike, BinaryIO],
        batch_size: Optional[int] = None,
        storage_options: Optional[dict] = None,
-        use_content_defined_chunking: Optional[dict] = None,
+        use_content_defined_chunking: bool | dict = True,
        **parquet_writer_kwargs,
    ):
        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size or get_writer_batch_size(dataset.features)
        self.storage_options = storage_options or {}
        self.parquet_writer_kwargs = parquet_writer_kwargs
+        if use_content_defined_chunking is True:
+            use_content_defined_chunking = config.DEFAULT_CDC_OPTIONS
        self.use_content_defined_chunking = use_content_defined_chunking

    def write(self) -> int:
        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
-        use_content_defined_chunking = (
-            config.DEFAULT_CDC_OPTIONS
-            if self.use_content_defined_chunking is None
-            else self.use_content_defined_chunking
-        )

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with fsspec.open(self.path_or_buf, "wb", **(self.storage_options or {})) as buffer:
                written = self._write(
                    file_obj=buffer,
                    batch_size=batch_size,
-                    use_content_defined_chunking=use_content_defined_chunking,
                    **self.parquet_writer_kwargs,
                )
        else:
            written = self._write(
                file_obj=self.path_or_buf,
                batch_size=batch_size,
-                use_content_defined_chunking=use_content_defined_chunking,
                **self.parquet_writer_kwargs,
            )
        return written

-    def _write(
-        self, file_obj: BinaryIO, batch_size: int, use_content_defined_chunking: bool | dict, **parquet_writer_kwargs
-    ) -> int:
+    def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -> int:
        """Writes the pyarrow table as Parquet to a binary file handle.

        Caller is responsible for opening and closing the handle.
@@ -125,7 +118,10 @@ def _write(
        schema = self.dataset.features.arrow_schema

        writer = pq.ParquetWriter(
-            file_obj, schema=schema, use_content_defined_chunking=use_content_defined_chunking, **parquet_writer_kwargs
+            file_obj,
+            schema=schema,
+            use_content_defined_chunking=self.use_content_defined_chunking,
+            **parquet_writer_kwargs,
        )

        for offset in hf_tqdm(
@@ -142,8 +138,8 @@ def _write(
            written += batch.nbytes

        # TODO(kszucs): we may want to persist multiple parameters
-        if use_content_defined_chunking is not False:
-            writer.add_key_value_metadata({"content_defined_chunking": json.dumps(use_content_defined_chunking)})
+        if self.use_content_defined_chunking is not False:
+            writer.add_key_value_metadata({"content_defined_chunking": json.dumps(self.use_content_defined_chunking)})

        writer.close()
        return written
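
With this change, `use_content_defined_chunking` accepts three kinds of values and is normalized once in `__init__` rather than at each call site. A minimal usage sketch, assuming `my_dataset` is any `datasets.Dataset`; the variable name and output paths are placeholders, not part of the diff:

    # Placeholder usage sketch (not from the commit): `my_dataset` stands in
    # for any datasets.Dataset instance.
    from datasets import config
    from datasets.io.parquet import ParquetDatasetWriter

    # Default (True): __init__ substitutes config.DEFAULT_CDC_OPTIONS.
    ParquetDatasetWriter(my_dataset, "default.parquet").write()

    # False: content defined chunking is disabled and no KV metadata is persisted.
    ParquetDatasetWriter(my_dataset, "plain.parquet", use_content_defined_chunking=False).write()

    # dict: custom options are forwarded to pq.ParquetWriter and persisted verbatim.
    custom_cdc_options = {
        "min_chunk_size": 128 * 1024,  # 128 KiB
        "max_chunk_size": 512 * 1024,  # 512 KiB
        "norm_level": 1,
    }
    ParquetDatasetWriter(my_dataset, "custom.parquet", use_content_defined_chunking=custom_cdc_options).write()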

tests/io/test_parquet.py

Lines changed: 30 additions & 21 deletions
@@ -219,29 +219,38 @@ def test_parquet_write_uses_content_defined_chunking(dataset, tmp_path):
    assert kwargs["use_content_defined_chunking"] == config.DEFAULT_CDC_OPTIONS


-custom_cdc_options = {
-    "min_chunk_size": 128 * 1024,  # 128 KiB
-    "max_chunk_size": 512 * 1024,  # 512 KiB
-    "norm_level": 1,
-}
-
-
-@pytest.mark.parametrize(
-    ("cdc_options", "expected_options"), [(None, config.DEFAULT_CDC_OPTIONS), (custom_cdc_options, custom_cdc_options)]
-)
-def test_parquet_writer_persist_cdc_options_as_metadata(dataset, tmp_path, cdc_options, expected_options):
-    # write the dataset to parquet with the default CDC options
-    writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet", use_content_defined_chunking=cdc_options)
-    assert writer.write() > 0
-
-    # read the parquet KV metadata
-    metadata = pq.read_metadata(tmp_path / "foo.parquet")
-    key_value_metadata = metadata.metadata
-
-    # check that the content defined chunking options are persisted
+def test_parquet_writer_persist_cdc_options_as_metadata(dataset, tmp_path):
+    def write_and_get_metadata(**kwargs):
+        # write the dataset to parquet with the given CDC options
+        writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet", **kwargs)
+        assert writer.write() > 0
+
+        # read the parquet KV metadata
+        metadata = pq.read_metadata(tmp_path / "foo.parquet")
+        key_value_metadata = metadata.metadata
+
+        return key_value_metadata
+
+    # by default no argument is passed, which is equivalent to passing True and uses the default options
+    for key_value_metadata in [write_and_get_metadata(), write_and_get_metadata(use_content_defined_chunking=True)]:
+        assert b"content_defined_chunking" in key_value_metadata
+        json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
+        assert json.loads(json_encoded_options) == config.DEFAULT_CDC_OPTIONS
+
+    # passing False disables content defined chunking and doesn't persist the options in metadata
+    key_value_metadata = write_and_get_metadata(use_content_defined_chunking=False)
+    assert b"content_defined_chunking" not in key_value_metadata
+
+    # passing custom options persists exactly those options
+    custom_cdc_options = {
+        "min_chunk_size": 128 * 1024,  # 128 KiB
+        "max_chunk_size": 512 * 1024,  # 512 KiB
+        "norm_level": 1,
+    }
+    key_value_metadata = write_and_get_metadata(use_content_defined_chunking=custom_cdc_options)
    assert b"content_defined_chunking" in key_value_metadata
    json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
-    assert json.loads(json_encoded_options) == expected_options
+    assert json.loads(json_encoded_options) == custom_cdc_options


 def test_dataset_to_parquet_keeps_features(shared_datadir, tmp_path):
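
Because the chunking options are persisted as JSON under the `content_defined_chunking` key, they can be recovered from any file the writer produced. A minimal read-back sketch mirroring the test's assertions, where `path` is a placeholder for a Parquet file written as above:

    # Read-back sketch (not from the commit); `path` is a placeholder for a
    # file produced by ParquetDatasetWriter.
    import json
    import pyarrow.parquet as pq

    key_value_metadata = pq.read_metadata(path).metadata  # bytes -> bytes mapping
    raw = key_value_metadata.get(b"content_defined_chunking")
    if raw is None:
        print("content defined chunking was disabled")  # written with False
    else:
        print(json.loads(raw.decode("utf-8")))  # default or custom CDC options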
