@@ -251,6 +251,11 @@ class DatasetBuilder:
             `os.path.join(data_dir, "**")` as `data_files`.
             For builders that require manual download, it must be the path to the local directory containing the
             manually downloaded data.
+        writer_batch_size (`int`, *optional*):
+            Batch size used by the ArrowWriter.
+            It defines the number of samples that are kept in memory before writing them
+            and also the length of the arrow chunks.
+            None means that the ArrowWriter will use its default value.
         name (`str`): Configuration name for the dataset.

             <Deprecated version="2.3.0">
@@ -276,6 +281,12 @@ class DatasetBuilder:
     # Optional default config name to be used when name is None
     DEFAULT_CONFIG_NAME = None

+    # Default batch size used by the ArrowWriter
+    # It defines the number of samples that are kept in memory before writing them
+    # and also the length of the arrow chunks
+    # None means that the ArrowWriter will use its default value
+    DEFAULT_WRITER_BATCH_SIZE = None
+
     def __init__(
         self,
         cache_dir: Optional[str] = None,
@@ -288,6 +299,7 @@ def __init__(
         repo_id: Optional[str] = None,
         data_files: Optional[Union[str, list, dict, DataFilesDict]] = None,
         data_dir: Optional[str] = None,
+        writer_batch_size: Optional[int] = None,
         name="deprecated",
         **config_kwargs,
     ):
@@ -303,6 +315,7 @@ def __init__(
         self.base_path = base_path
         self.use_auth_token = use_auth_token
         self.repo_id = repo_id
+        self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE

         if data_files is not None and not isinstance(data_files, DataFilesDict):
             data_files = DataFilesDict.from_local_or_remote(
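With the hunks above, `writer_batch_size` is accepted directly by `DatasetBuilder.__init__` and stored as `self._writer_batch_size`. A minimal sketch of how this could be exercised from user code, assuming `load_dataset_builder` forwards extra builder keyword arguments to the constructor; the dataset name and batch size below are illustrative, not part of the patch:

```python
from datasets import load_dataset_builder

# Hypothetical usage: "squad" and 256 are example values, not taken from the patch.
builder = load_dataset_builder("squad", writer_batch_size=256)

# The base class now stores the value and forwards it to the ArrowWriter later on,
# so smaller values keep fewer examples in memory but produce shorter Arrow chunks.
assert builder._writer_batch_size == 256
builder.download_and_prepare()
```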
@@ -1384,23 +1397,6 @@ class GeneratorBasedBuilder(DatasetBuilder):
     (`_split_generators`). See the method docstrings for details.
     """

-    # GeneratorBasedBuilder should have dummy data for tests by default
-    test_dummy_data = True
-
-    # Default batch size used by the ArrowWriter
-    # It defines the number of samples that are kept in memory before writing them
-    # and also the length of the arrow chunks
-    # None means that the ArrowWriter will use its default value
-    DEFAULT_WRITER_BATCH_SIZE = None
-
-    def __init__(self, *args, writer_batch_size=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Batch size used by the ArrowWriter
-        # It defines the number of samples that are kept in memory before writing them
-        # and also the length of the arrow chunks
-        # None means that the ArrowWriter will use its default value
-        self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE
-
     @abc.abstractmethod
     def _generate_examples(self, **kwargs):
         """Default function generating examples for each `SplitGenerator`.
@@ -1662,9 +1658,6 @@ def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> E
 class ArrowBasedBuilder(DatasetBuilder):
     """Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet)."""

-    # ArrowBasedBuilder should have dummy data for tests by default
-    test_dummy_data = True
-
     @abc.abstractmethod
     def _generate_tables(self, **kwargs):
         """Default function generating examples for each `SplitGenerator`.
@@ -1853,6 +1846,7 @@ def _prepare_split_single(
             writer = writer_class(
                 features=self.info.features,
                 path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                writer_batch_size=self._writer_batch_size,
                 storage_options=self._fs.storage_options,
                 embed_local_files=embed_local_files,
             )
@@ -1869,6 +1863,7 @@ def _prepare_split_single(
                     writer = writer_class(
                         features=writer._features,
                         path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                        writer_batch_size=self._writer_batch_size,
                         storage_options=self._fs.storage_options,
                         embed_local_files=embed_local_files,
                     )
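Both writer constructions above now receive `self._writer_batch_size`. For context, a small standalone sketch of what that parameter changes at the `ArrowWriter` level (an internal helper, used here for illustration only; the path, feature names and sizes are made up): examples are buffered until the batch size is reached, so it bounds both memory usage and the length of the written Arrow record batches.

```python
import pyarrow as pa

from datasets import Features, Value
from datasets.arrow_writer import ArrowWriter  # internal helper, shown for illustration only

features = Features({"text": Value("string")})
writer = ArrowWriter(features=features, path="tmp_split.arrow", writer_batch_size=10)
for i in range(25):
    writer.write({"text": f"example {i}"})  # kept in memory until 10 examples accumulate
num_examples, num_bytes = writer.finalize()  # flushes the remaining 5 examples and closes the file

# The stream file is chunked per flush: expected batch lengths here are [10, 10, 5].
with pa.memory_map("tmp_split.arrow") as source:
    table = pa.ipc.open_stream(source).read_all()
print([len(batch) for batch in table.to_batches()])
```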
@@ -1907,9 +1902,6 @@ class MissingBeamOptions(ValueError):
 class BeamBasedBuilder(DatasetBuilder):
     """Beam-based Builder."""

-    # BeamBasedBuilder does not have dummy data for tests yet
-    test_dummy_data = False
-
     def __init__(self, *args, beam_runner=None, beam_options=None, **kwargs):
         self._beam_runner = beam_runner
         self._beam_options = beam_options
@@ -1988,6 +1980,10 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_
                 "`DirectRunner` (you may run out of memory). \nExample of usage: "
                 f"\n\t`{usage_example}`"
             )
+        if self._writer_batch_size is not None:
+            logger.warning(
+                "`writer_batch_size` is not supported for beam pipelines yet. Using the default chunk size for writing."
+            )

         # Beam type checking assumes transforms multiple outputs are of same type,
         # which is not our case. Plus it doesn't handle correctly all types, so we