Skip to content

Commit b306a37

Browse files
committed
test: cover more input parameter values for ParquetWriter
1 parent ee1e73b commit b306a37

File tree

2 files changed

+60
-32
lines changed

2 files changed

+60
-32
lines changed

src/datasets/arrow_writer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -679,11 +679,11 @@ def finalize(self, close_stream=True):
679679

680680

681681
class ParquetWriter(ArrowWriter):
682-
def __init__(self, *args, use_content_defined_chunking=None, **kwargs):
682+
def __init__(self, *args, use_content_defined_chunking=True, **kwargs):
683683
super().__init__(*args, **kwargs)
684-
self.use_content_defined_chunking = (
685-
config.DEFAULT_CDC_OPTIONS if use_content_defined_chunking is None else use_content_defined_chunking
686-
)
684+
if use_content_defined_chunking is True:
685+
use_content_defined_chunking = config.DEFAULT_CDC_OPTIONS
686+
self.use_content_defined_chunking = use_content_defined_chunking
687687

688688
def _build_writer(self, inferred_schema: pa.Schema):
689689
self._schema, self._features = self._build_schema(inferred_schema)

tests/test_arrow_writer.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -336,37 +336,65 @@ def test_parquet_writer_write():
336336
assert pa_table.to_pydict() == {"col_1": ["foo", "bar"], "col_2": [1, 2]}
337337

338338

339-
custom_cdc_options = {
340-
"min_chunk_size": 128 * 1024, # 128 KiB
341-
"max_chunk_size": 512 * 1024, # 512 KiB
342-
"norm_level": 1,
343-
}
344-
339+
def test_parquet_writer_uses_content_defined_chunking():
340+
def write_and_get_argument_and_metadata(**kwargs):
341+
output = pa.BufferOutputStream()
342+
with patch("pyarrow.parquet.ParquetWriter", wraps=pq.ParquetWriter) as MockWriter:
343+
with ParquetWriter(stream=output, **kwargs) as writer:
344+
writer.write({"col_1": "foo", "col_2": 1})
345+
writer.write({"col_1": "bar", "col_2": 2})
346+
writer.finalize()
347+
assert MockWriter.call_count == 1
348+
_, kwargs = MockWriter.call_args
349+
assert "use_content_defined_chunking" in kwargs
350+
351+
# read metadata from the output stream
352+
with pa.input_stream(output.getvalue()) as stream:
353+
metadata = pq.read_metadata(stream)
354+
key_value_metadata = metadata.metadata
355+
356+
return kwargs["use_content_defined_chunking"], key_value_metadata
357+
358+
# not passing the use_content_defined_chunking argument, using the default
359+
passed_arg, key_value_metadata = write_and_get_argument_and_metadata()
360+
assert passed_arg == config.DEFAULT_CDC_OPTIONS
361+
assert b"content_defined_chunking" in key_value_metadata
362+
json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
363+
assert json.loads(json_encoded_options) == config.DEFAULT_CDC_OPTIONS
345364

346-
@pytest.mark.parametrize(
347-
("cdc_options", "expected_options"), [(None, config.DEFAULT_CDC_OPTIONS), (custom_cdc_options, custom_cdc_options)]
348-
)
349-
def test_parquet_write_uses_content_defined_chunking(cdc_options, expected_options):
350-
output = pa.BufferOutputStream()
351-
with patch("pyarrow.parquet.ParquetWriter", wraps=pq.ParquetWriter) as MockWriter:
352-
with ParquetWriter(stream=output, use_content_defined_chunking=cdc_options) as writer:
353-
writer.write({"col_1": "foo", "col_2": 1})
354-
writer.write({"col_1": "bar", "col_2": 2})
355-
writer.finalize()
356-
assert MockWriter.call_count == 1
357-
_, kwargs = MockWriter.call_args
358-
assert "use_content_defined_chunking" in kwargs
359-
assert kwargs["use_content_defined_chunking"] == expected_options
360-
361-
# read metadata from the output stream
362-
with pa.input_stream(output.getvalue()) as stream:
363-
metadata = pq.read_metadata(stream)
364-
key_value_metadata = metadata.metadata
365-
366-
# check that the content defined chunking options are persisted
365+
# passing True, using the default options
366+
passed_arg, key_value_metadata = write_and_get_argument_and_metadata(use_content_defined_chunking=True)
367+
assert passed_arg == config.DEFAULT_CDC_OPTIONS
368+
assert b"content_defined_chunking" in key_value_metadata
369+
json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
370+
assert json.loads(json_encoded_options) == config.DEFAULT_CDC_OPTIONS
371+
372+
# passing False, not using content defined chunking
373+
passed_arg, key_value_metadata = write_and_get_argument_and_metadata(use_content_defined_chunking=False)
374+
assert passed_arg is False
375+
assert b"content_defined_chunking" not in key_value_metadata
376+
377+
# passing custom options, using the custom options
378+
custom_cdc_options = {
379+
"min_chunk_size": 128 * 1024, # 128 KiB
380+
"max_chunk_size": 512 * 1024, # 512 KiB
381+
"norm_level": 1,
382+
}
383+
passed_arg, key_value_metadata = write_and_get_argument_and_metadata(
384+
use_content_defined_chunking=custom_cdc_options
385+
)
386+
assert passed_arg == custom_cdc_options
367387
assert b"content_defined_chunking" in key_value_metadata
368388
json_encoded_options = key_value_metadata[b"content_defined_chunking"].decode("utf-8")
369-
assert json.loads(json_encoded_options) == expected_options
389+
assert json.loads(json_encoded_options) == custom_cdc_options
390+
391+
    # passing None or invalid options raises an error from pyarrow
392+
with pytest.raises(TypeError):
393+
write_and_get_argument_and_metadata(use_content_defined_chunking=None)
394+
with pytest.raises(TypeError):
395+
write_and_get_argument_and_metadata(use_content_defined_chunking="invalid_options")
396+
with pytest.raises(ValueError):
397+
write_and_get_argument_and_metadata(use_content_defined_chunking={"invalid_option": 1})
370398

371399

372400
@require_pil

0 commit comments

Comments
 (0)