 from .keyhash import DuplicatedKeysError, KeyHasher
 from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
 from .utils import logging
-from .utils.py_utils import asdict, first_non_null_non_empty_value
+from .utils.py_utils import asdict, convert_file_size_to_int, first_non_null_non_empty_value
 
 
 logger = logging.get_logger(__name__)
 
 
 type_ = type  # keep python's type function
 
 
-def get_writer_batch_size(features: Optional[Features]) -> Optional[int]:
+def get_arrow_writer_batch_size_from_features(features: Optional[Features]) -> Optional[int]:
     """
-    Get the writer_batch_size that defines the maximum row group size in the parquet files.
-    The default in `datasets` is 1,000 but we lower it to 100 for image/audio datasets and 10 for videos.
+    Get the writer_batch_size that defines the maximum record batch size in the Arrow files, based on configuration values.
+    The default value is 100 for image/audio datasets and 10 for video datasets.
+    This helps avoid overflows in Arrow buffers.
+
+    Args:
+        features (`datasets.Features` or `None`):
+            Dataset Features from `datasets`.
+    Returns:
+        writer_batch_size (`Optional[int]`):
+            Writer batch size to pass to a dataset builder.
+            If `None`, then it will use the `datasets` default, i.e. `datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+    """
+    if not features:
+        return None
+
+    batch_size = np.inf
+
+    def set_batch_size(feature: FeatureType) -> None:
+        nonlocal batch_size
+        if isinstance(feature, Image) and config.ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS is not None:
+            batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS)
+        elif isinstance(feature, Audio) and config.ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS is not None:
+            batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_AUDIO_DATASETS)
+        elif isinstance(feature, Video) and config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS is not None:
+            batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_VIDEO_DATASETS)
+        elif (
+            isinstance(feature, Value)
+            and feature.dtype == "binary"
+            and config.ARROW_RECORD_BATCH_SIZE_FOR_BINARY_DATASETS is not None
+        ):
+            batch_size = min(batch_size, config.ARROW_RECORD_BATCH_SIZE_FOR_BINARY_DATASETS)
+
+    _visit(features, set_batch_size)
+
+    return None if batch_size is np.inf else batch_size
+
+
+def get_writer_batch_size_from_features(features: Optional[Features]) -> Optional[int]:
+    """
+    Get the writer_batch_size that defines the maximum row group size in the parquet files, based on configuration values.
+    By default these values are not set, but it can be helpful to hard-set them in some cases.
     This allows to optimize random access to parquet file, since accessing 1 row requires
     to read its entire row group.
 
-    This can be improved to get optimized size for querying/iterating
-    but at least it matches the dataset viewer expectations on HF.
-
     Args:
         features (`datasets.Features` or `None`):
             Dataset Features from `datasets`.
     Returns:
         writer_batch_size (`Optional[int]`):
-            Writer batch size to pass to a dataset builder.
-            If `None`, then it will use the `datasets` default.
+            Writer batch size to pass to a parquet writer.
+            If `None`, then it will use the `datasets` default, i.e. aiming for row groups of 100MB.
     """
     if not features:
         return None
@@ -76,20 +112,48 @@ def get_writer_batch_size(features: Optional[Features]) -> Optional[int]:
 
     def set_batch_size(feature: FeatureType) -> None:
         nonlocal batch_size
-        if isinstance(feature, Image):
+        if isinstance(feature, Image) and config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS is not None:
             batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
-        elif isinstance(feature, Audio):
+        elif isinstance(feature, Audio) and config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS is not None:
             batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
-        elif isinstance(feature, Video):
+        elif isinstance(feature, Video) and config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS is not None:
             batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
-        elif isinstance(feature, Value) and feature.dtype == "binary":
+        elif (
+            isinstance(feature, Value)
+            and feature.dtype == "binary"
+            and config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS is not None
+        ):
             batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
 
     _visit(features, set_batch_size)
 
     return None if batch_size is np.inf else batch_size
 
 
+def get_writer_batch_size_from_data_size(num_rows: int, num_bytes: int) -> int:
+    """
+    Get the writer_batch_size that defines the maximum row group size in the parquet files.
+    The default in `datasets` aims for row groups of at most 100MB uncompressed.
+    This allows optimizing random access to the parquet file, since accessing 1 row requires
+    reading its entire row group.
+
+    This can be improved to get an optimized size for querying/iterating,
+    but at least it matches the dataset viewer expectations on HF.
+
+    Args:
+        num_rows (`int`):
+            Number of rows in the dataset.
+        num_bytes (`int`):
+            Number of bytes in the dataset.
+            For datasets with external files to embed (image, audio, video), this can also be an
+            estimate from `dataset._estimate_nbytes()`.
+    Returns:
+        writer_batch_size (`int`):
+            Writer batch size to pass to a parquet writer.
153+ """
154+ return max (10 , num_rows * convert_file_size_to_int (config .MAX_ROW_GROUP_SIZE ) // num_bytes )
155+
156+
93157class SchemaInferenceError (ValueError ):
94158 pass
95159
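
To make the two new sizing helpers concrete, here is a small sketch of how they behave. The concrete config values and dataset sizes below are illustrative assumptions, not values defined in this diff:

    # Illustrative sketch (the numeric values are assumptions, not taken from datasets.config).
    from datasets import Features, Image, Value

    # 1) Feature-based: an Image column caps the record batch size at the configured value,
    #    e.g. if ARROW_RECORD_BATCH_SIZE_FOR_IMAGE_DATASETS were set to 100:
    features = Features({"image": Image(), "label": Value("int64")})
    # get_arrow_writer_batch_size_from_features(features) -> 100 (smallest applicable cap)
    # get_arrow_writer_batch_size_from_features(Features({"label": Value("int64")})) -> None (no cap)

    # 2) Data-size based: aim for ~100MB row groups, but never fewer than 10 rows per group.
    num_rows, num_bytes = 1_000_000, 50 * 1024**3  # ~50 GiB total, i.e. ~52 KiB per row on average
    max_row_group_bytes = 100 * 10**6              # assumed "100MB" target, as in the docstring above
    batch = max(10, num_rows * max_row_group_bytes // num_bytes)  # -> 1862 rows per row group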
@@ -342,8 +406,6 @@ def __init__(
 class ArrowWriter:
     """Shuffles and writes Examples to Arrow files."""
 
-    _WRITER_CLASS = pa.RecordBatchStreamWriter
-
     def __init__(
         self,
         schema: Optional[pa.Schema] = None,
@@ -397,7 +459,9 @@ def __init__(
         self.fingerprint = fingerprint
         self.disable_nullable = disable_nullable
         self.writer_batch_size = (
-            writer_batch_size or get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
+            writer_batch_size
+            or get_arrow_writer_batch_size_from_features(self._features)
+            or config.DEFAULT_MAX_BATCH_SIZE
         )
         self.update_features = update_features
         self.with_metadata = with_metadata
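
The resolution order above is: an explicit writer_batch_size argument wins, then the feature-based cap, then the global default. A rough sketch of the fallback chain, with illustrative values (the actual config values may differ):

    # Assumed values for illustration only.
    explicit_writer_batch_size = None  # the caller did not pass writer_batch_size
    feature_based_cap = 100            # e.g. the features contain an Image column
    global_default = 1000              # e.g. config.DEFAULT_MAX_BATCH_SIZE
    resolved = explicit_writer_batch_size or feature_based_cap or global_default
    assert resolved == 100  # the first non-falsy value wins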
@@ -431,8 +495,9 @@ def close(self):
         if self._closable_stream and not self.stream.closed:
             self.stream.close()  # This also closes self.pa_writer if it is opened
 
-    def _build_writer(self, inferred_schema: pa.Schema):
+    def _build_schema(self, inferred_schema: pa.Schema):
         schema = self.schema
+        features = self._features
         inferred_features = Features.from_arrow_schema(inferred_schema)
         if self._features is not None:
             if self.update_features:  # keep original features it they match, or update them
@@ -442,19 +507,24 @@ def _build_writer(self, inferred_schema: pa.Schema):
                     if name in fields:
                         if inferred_field == fields[name]:
                             inferred_features[name] = self._features[name]
-                self._features = inferred_features
+                features = inferred_features
                 schema: pa.Schema = inferred_schema
         else:
-            self._features = inferred_features
+            features = inferred_features
             schema: pa.Schema = inferred_features.arrow_schema
+
         if self.disable_nullable:
             schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in schema)
         if self.with_metadata:
-            schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint))
+            schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=features), self.fingerprint))
         else:
             schema = schema.with_metadata({})
-        self._schema = schema
-        self.pa_writer = self._WRITER_CLASS(self.stream, schema)
+
+        return schema, features
+
+    def _build_writer(self, inferred_schema: pa.Schema):
+        self._schema, self._features = self._build_schema(inferred_schema)
+        self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
 
     @property
     def schema(self):
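
Splitting the old _build_writer into _build_schema (schema/features resolution, no writer side effects) plus a thin _build_writer (instantiation of the concrete writer) is what lets ParquetWriter below pass writer-specific options instead of relying on a shared _WRITER_CLASS attribute. As a rough illustration of what this enables, a subclass could reuse _build_schema and only swap the backend; the ArrowFileWriter class below is a hypothetical sketch, not part of this diff:

    # Hypothetical sketch (not in this diff): reuse _build_schema(), swap only the writer backend.
    import pyarrow as pa

    class ArrowFileWriter(ArrowWriter):
        def _build_writer(self, inferred_schema: pa.Schema):
            # schema/features resolution is shared with the base class
            self._schema, self._features = self._build_schema(inferred_schema)
            # only the backend differs: the Arrow IPC file (random-access) format
            self.pa_writer = pa.ipc.new_file(self.stream, self._schema)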
@@ -675,4 +745,22 @@ def finalize(self, close_stream=True):
 
 
 class ParquetWriter(ArrowWriter):
-    _WRITER_CLASS = pq.ParquetWriter
+    def __init__(self, *args, use_content_defined_chunking=True, write_page_index=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        if use_content_defined_chunking is True:
+            use_content_defined_chunking = config.DEFAULT_CDC_OPTIONS
+        self.use_content_defined_chunking = use_content_defined_chunking
+        self.write_page_index = write_page_index
+
+    def _build_writer(self, inferred_schema: pa.Schema):
+        self._schema, self._features = self._build_schema(inferred_schema)
+        self.pa_writer = pq.ParquetWriter(
+            self.stream,
+            self._schema,
+            use_content_defined_chunking=self.use_content_defined_chunking,
+            write_page_index=self.write_page_index,
+        )
+        if self.use_content_defined_chunking is not False:
+            self.pa_writer.add_key_value_metadata(
+                {"content_defined_chunking": json.dumps(self.use_content_defined_chunking)}
+            )