
Commit c5ca1d8

lhoestq and mariosasko authored
Add writer_batch_size for ArrowBasedBuilder (#5565)
* add writer_batch_size to ArrowBasedBuilder

* style

* Update src/datasets/builder.py
  Co-authored-by: Mario Šaško <[email protected]>

---------

Co-authored-by: Mario Šaško <[email protected]>
1 parent 778d4e1 commit c5ca1d8

File tree

1 file changed: +19 additions, -23 deletions


src/datasets/builder.py

Lines changed: 19 additions & 23 deletions
@@ -251,6 +251,11 @@ class DatasetBuilder:
             `os.path.join(data_dir, "**")` as `data_files`.
             For builders that require manual download, it must be the path to the local directory containing the
             manually downloaded data.
+        writer_batch_size (`int`, *optional*):
+            Batch size used by the ArrowWriter.
+            It defines the number of samples that are kept in memory before writing them
+            and also the length of the arrow chunks.
+            None means that the ArrowWriter will use its default value.
         name (`str`): Configuration name for the dataset.

         <Deprecated version="2.3.0">
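For context, a quick sketch of how the new parameter is meant to be used from the loading API. `load_dataset` forwards extra keyword arguments to the builder, so this should apply to packaged Arrow-based builders like csv/json/parquet (the "my_data.csv" file name is a placeholder):

from datasets import load_dataset

# Smaller write batches lower peak memory during dataset preparation,
# at the cost of more (shorter) Arrow chunks in the cached dataset.
ds = load_dataset("csv", data_files="my_data.csv", writer_batch_size=100)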
@@ -276,6 +281,12 @@ class DatasetBuilder:
     # Optional default config name to be used when name is None
     DEFAULT_CONFIG_NAME = None

+    # Default batch size used by the ArrowWriter
+    # It defines the number of samples that are kept in memory before writing them
+    # and also the length of the arrow chunks
+    # None means that the ArrowWriter will use its default value
+    DEFAULT_WRITER_BATCH_SIZE = None
+
     def __init__(
         self,
         cache_dir: Optional[str] = None,
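A subclass can also tune this new class-level default, which `__init__` falls back to when no `writer_batch_size` is passed. A minimal sketch (the `MyImageDataset` builder is hypothetical; small defaults make sense when individual samples are large):

import datasets

class MyImageDataset(datasets.GeneratorBasedBuilder):
    # Large samples (e.g. decoded images) -> keep fewer of them in memory per write.
    DEFAULT_WRITER_BATCH_SIZE = 100

    def _info(self):
        ...

    def _split_generators(self, dl_manager):
        ...

    def _generate_examples(self, **kwargs):
        ...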
@@ -288,6 +299,7 @@ def __init__(
         repo_id: Optional[str] = None,
         data_files: Optional[Union[str, list, dict, DataFilesDict]] = None,
         data_dir: Optional[str] = None,
+        writer_batch_size: Optional[int] = None,
         name="deprecated",
         **config_kwargs,
     ):
@@ -303,6 +315,7 @@ def __init__(
         self.base_path = base_path
         self.use_auth_token = use_auth_token
         self.repo_id = repo_id
+        self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE

         if data_files is not None and not isinstance(data_files, DataFilesDict):
             data_files = DataFilesDict.from_local_or_remote(
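One detail worth noting about the added line: the fallback uses `or` rather than an explicit `is None` check, so any falsy value is replaced by the class default:

# self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE
#   writer_batch_size=1000 -> 1000
#   writer_batch_size=None -> DEFAULT_WRITER_BATCH_SIZE
#   writer_batch_size=0    -> DEFAULT_WRITER_BATCH_SIZE too, since 0 is falsy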
@@ -1384,23 +1397,6 @@ class GeneratorBasedBuilder(DatasetBuilder):
     (`_split_generators`). See the method docstrings for details.
     """

-    # GeneratorBasedBuilder should have dummy data for tests by default
-    test_dummy_data = True
-
-    # Default batch size used by the ArrowWriter
-    # It defines the number of samples that are kept in memory before writing them
-    # and also the length of the arrow chunks
-    # None means that the ArrowWriter will use its default value
-    DEFAULT_WRITER_BATCH_SIZE = None
-
-    def __init__(self, *args, writer_batch_size=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Batch size used by the ArrowWriter
-        # It defines the number of samples that are kept in memory before writing them
-        # and also the length of the arrow chunks
-        # None means that the ArrowWriter will use its default value
-        self._writer_batch_size = writer_batch_size or self.DEFAULT_WRITER_BATCH_SIZE
-
     @abc.abstractmethod
     def _generate_examples(self, **kwargs):
        """Default function generating examples for each `SplitGenerator`.
@@ -1662,9 +1658,6 @@ def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> E
 class ArrowBasedBuilder(DatasetBuilder):
     """Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet)."""

-    # ArrowBasedBuilder should have dummy data for tests by default
-    test_dummy_data = True
-
     @abc.abstractmethod
     def _generate_tables(self, **kwargs):
         """Default function generating examples for each `SplitGenerator`.
@@ -1853,6 +1846,7 @@ def _prepare_split_single(
             writer = writer_class(
                 features=self.info.features,
                 path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                writer_batch_size=self._writer_batch_size,
                 storage_options=self._fs.storage_options,
                 embed_local_files=embed_local_files,
             )
@@ -1869,6 +1863,7 @@ def _prepare_split_single(
             writer = writer_class(
                 features=writer._features,
                 path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
+                writer_batch_size=self._writer_batch_size,
                 storage_options=self._fs.storage_options,
                 embed_local_files=embed_local_files,
             )
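Since the writer flushes every `writer_batch_size` examples, the setting should also be visible afterwards as the chunk length of the cached Arrow data. A rough way to check, assuming `ds.data` exposes the underlying pyarrow table's `to_batches()` (the file name is a placeholder, and the exact chunk layout can vary at flush/shard boundaries):

from datasets import load_dataset

ds = load_dataset("csv", data_files="my_data.csv", split="train", writer_batch_size=100)
# Most record batches should contain at most 100 rows.
print([len(batch) for batch in ds.data.to_batches()])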
@@ -1907,9 +1902,6 @@ class MissingBeamOptions(ValueError):
 class BeamBasedBuilder(DatasetBuilder):
     """Beam-based Builder."""

-    # BeamBasedBuilder does not have dummy data for tests yet
-    test_dummy_data = False
-
     def __init__(self, *args, beam_runner=None, beam_options=None, **kwargs):
         self._beam_runner = beam_runner
         self._beam_options = beam_options
@@ -1988,6 +1980,10 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_
                 "`DirectRunner` (you may run out of memory). \nExample of usage: "
                 f"\n\t`{usage_example}`"
             )
+        if self._writer_batch_size is not None:
+            logger.warning(
+                "`writer_batch_size` is not supported for beam pipelines yet. Using the default chunk size for writing."
+            )

        # Beam type checking assumes transforms multiple outputs are of same type,
        # which is not our case. Plus it doesn't handle correctly all types, so we
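In other words, a `BeamBasedBuilder` accepts the argument (it reaches `DatasetBuilder.__init__`) but ignores it for now and only logs the warning above. A sketch of the observable behavior ("wikipedia" is used here purely as an example of a beam-based dataset at the time of this commit):

from datasets import load_dataset_builder

builder = load_dataset_builder("wikipedia", "20220301.en", writer_batch_size=100)
# builder.download_and_prepare(...) would log:
#   `writer_batch_size` is not supported for beam pipelines yet. Using the default chunk size for writing.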
