Skip to content

Commit 2027296

Browse files
committed
add _generate_shards
1 parent 38d28bf commit 2027296

File tree

14 files changed

+321
-227
lines changed

14 files changed

+321
-227
lines changed

src/datasets/builder.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,26 @@ class GeneratorBasedBuilder(DatasetBuilder):
13351335
(`_split_generators`). See the method docstrings for details.
13361336
"""
13371337

1338+
def _generate_shards(self, **kwargs) -> Iterator[Union[str, dict[str, Any]]]:
1339+
"""Default function generating shards paths for each `SplitGenerator`.
1340+
1341+
This function is useful to list the original shards from where the data
1342+
comes from and is either converted to Arrow or streamed to an IterableDataset.
1343+
1344+
This is optional and only used for certain utilities, but not in Dataset
1345+
nor IterableDataset. E.g. it's used to map original shard files to Parquet
1346+
files in the Dataset Viewer after conversion.
1347+
1348+
Args:
1349+
**kwargs (additional keyword arguments):
1350+
Arguments forwarded from the SplitGenerator.gen_kwargs
1351+
1352+
Yields:
1353+
shard: generally a string representing the shard path, or a dict
1354+
representing the shard in case of shards spanning intra or inter-files.
1355+
"""
1356+
raise NotImplementedError()
1357+
13381358
@abc.abstractmethod
13391359
def _generate_examples(self, **kwargs) -> Iterator[tuple[Key, dict[str, Any]]]:
13401360
"""Default function generating examples for each `SplitGenerator`.
@@ -1624,6 +1644,26 @@ def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> E
16241644
class ArrowBasedBuilder(DatasetBuilder):
16251645
"""Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet)."""
16261646

1647+
def _generate_shards(self, **kwargs) -> Iterator[Union[str, dict[str, Any]]]:
1648+
"""Default function generating shards paths for each `SplitGenerator`.
1649+
1650+
This function is useful to list the original shards from where the data
1651+
comes from and is either converted to Arrow or streamed to an IterableDataset.
1652+
1653+
This is optional and only used for certain utilities, but not in Dataset
1654+
nor IterableDataset. E.g. it's used to map original shard files to Parquet
1655+
files in the Dataset Viewer after conversion.
1656+
1657+
Args:
1658+
**kwargs (additional keyword arguments):
1659+
Arguments forwarded from the SplitGenerator.gen_kwargs
1660+
1661+
Yields:
1662+
shard: generally a string representing the shard path, or a dict
1663+
representing the shard in case of shards spanning intra or inter-files.
1664+
"""
1665+
raise NotImplementedError()
1666+
16271667
@abc.abstractmethod
16281668
def _generate_tables(self, **kwargs) -> Iterator[tuple[Key, pa.Table]]:
16291669
"""Default function generating examples for each `SplitGenerator`.

src/datasets/packaged_modules/arrow/arrow.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
from dataclasses import dataclass
32
from typing import Optional
43

@@ -32,17 +31,12 @@ def _split_generators(self, dl_manager):
3231
"""We handle string, list and dicts in datafiles"""
3332
if not self.config.data_files:
3433
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
35-
dl_manager.download_config.extract_on_the_fly = True
36-
data_files = dl_manager.download_and_extract(self.config.data_files)
34+
data_files = dl_manager.download(self.config.data_files)
3735
splits = []
3836
for split_name, files in data_files.items():
39-
if isinstance(files, str):
40-
files = [files]
41-
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
42-
files = [dl_manager.iter_files(file) for file in files]
4337
# Infer features if they are stored in the arrow schema
4438
if self.info.features is None:
45-
for file in itertools.chain.from_iterable(files):
39+
for file in files:
4640
with open(file, "rb") as f:
4741
try:
4842
reader = pa.ipc.open_stream(f)
@@ -60,8 +54,11 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
6054
pa_table = table_cast(pa_table, self.info.features.arrow_schema)
6155
return pa_table
6256

57+
def _generate_shards(self, files):
58+
yield from files
59+
6360
def _generate_tables(self, files):
64-
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
61+
for file_idx, file in enumerate(files):
6562
with open(file, "rb") as f:
6663
try:
6764
try:

src/datasets/packaged_modules/cache/cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ def _split_generators(self, dl_manager):
176176
for split_info in split_infos
177177
]
178178

179+
def _generate_shards(self, files):
180+
yield from files
181+
179182
def _generate_tables(self, files):
180183
# used to stream from cache
181184
for file_idx, file in enumerate(files):

src/datasets/packaged_modules/csv/csv.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
from dataclasses import dataclass
32
from typing import Any, Callable, Optional, Union
43

@@ -154,13 +153,17 @@ def _split_generators(self, dl_manager):
154153
if not self.config.data_files:
155154
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
156155
dl_manager.download_config.extract_on_the_fly = True
157-
data_files = dl_manager.download_and_extract(self.config.data_files)
156+
base_data_files = dl_manager.download(self.config.data_files)
157+
extracted_data_files = dl_manager.extract(base_data_files)
158158
splits = []
159-
for split_name, files in data_files.items():
160-
if isinstance(files, str):
161-
files = [files]
162-
files = [dl_manager.iter_files(file) for file in files]
163-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
159+
for split_name, extracted_files in extracted_data_files.items():
160+
files_iterables = [dl_manager.iter_files(extracted_file) for extracted_file in extracted_files]
161+
splits.append(
162+
datasets.SplitGenerator(
163+
name=split_name,
164+
gen_kwargs={"files_iterables": files_iterables, "base_files": base_data_files[split_name]},
165+
)
166+
)
164167
return splits
165168

166169
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
@@ -174,7 +177,10 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
174177
pa_table = table_cast(pa_table, schema)
175178
return pa_table
176179

177-
def _generate_tables(self, files):
180+
def _generate_shards(self, base_files, files_iterables):
181+
yield from base_files
182+
183+
def _generate_tables(self, base_files, files_iterables):
178184
schema = self.config.features.arrow_schema if self.config.features else None
179185
# dtype allows reading an int column as str
180186
dtype = (
@@ -185,15 +191,16 @@ def _generate_tables(self, files):
185191
if schema is not None
186192
else None
187193
)
188-
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
189-
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
190-
try:
191-
for batch_idx, df in enumerate(csv_file_reader):
192-
pa_table = pa.Table.from_pandas(df)
193-
# Uncomment for debugging (will print the Arrow table size and elements)
194-
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
195-
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
196-
yield Key(file_idx, batch_idx), self._cast_table(pa_table)
197-
except ValueError as e:
198-
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
199-
raise
194+
for shard_idx, files_iterable in enumerate(files_iterables):
195+
for file in files_iterable:
196+
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
197+
try:
198+
for batch_idx, df in enumerate(csv_file_reader):
199+
pa_table = pa.Table.from_pandas(df)
200+
# Uncomment for debugging (will print the Arrow table size and elements)
201+
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
202+
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
203+
yield Key(shard_idx, batch_idx), self._cast_table(pa_table)
204+
except ValueError as e:
205+
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
206+
raise

src/datasets/packaged_modules/eval/eval.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33
from itertools import islice
4+
from typing import Iterable
45

56
import pyarrow as pa
67

@@ -22,16 +23,26 @@ def _split_generators(self, dl_manager):
2223
if not self.config.data_files:
2324
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
2425
dl_manager.download_config.extract_on_the_fly = True
25-
data_files = dl_manager.download_and_extract(self.config.data_files)
26+
base_data_files = dl_manager.download(self.config.data_files)
27+
extracted_data_files = dl_manager.extract(base_data_files)
2628
splits = []
27-
for split_name, logs in data_files.items():
28-
if isinstance(logs, str):
29-
logs = [logs]
30-
logs_files = [dl_manager.iter_files(log) for log in logs]
31-
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"logs_files": logs_files}))
29+
for split_name, logs in extracted_data_files.items():
30+
logs_files_iterables = [dl_manager.iter_files(log) for log in logs]
31+
splits.append(
32+
datasets.SplitGenerator(
33+
name=split_name,
34+
gen_kwargs={
35+
"logs_files_iterables": logs_files_iterables,
36+
"base_files": base_data_files[split_name],
37+
},
38+
)
39+
)
3240
if not self.info.features:
3341
first_examples = list(
34-
islice(self._iter_samples_from_log_files(logs_files[0]), self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)
42+
islice(
43+
self._iter_samples_from_log_files(logs_files_iterables[0]),
44+
self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE,
45+
)
3546
)
3647
pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
3748
inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
@@ -44,7 +55,7 @@ def _sort_samples_key(self, sample_path: str):
4455
(sample_idx_str, epoch_idx_str) = os.path.splitext(os.path.basename(sample_path))[0].split("_epoch_")
4556
return (int(epoch_idx_str), int(sample_idx_str))
4657

47-
def _iter_samples_from_log_files(self, log_files: list[str]):
58+
def _iter_samples_from_log_files(self, log_files: Iterable[str]):
4859
sample_files = [log_file for log_file in log_files if os.path.basename(os.path.dirname(log_file)) == "samples"]
4960
sample_files.sort(key=self._sort_samples_key)
5061
for sample_file in sample_files:
@@ -57,7 +68,10 @@ def _iter_samples_from_log_files(self, log_files: list[str]):
5768
sample[field] = [json.dumps(x) for x in sample[field]]
5869
yield sample
5970

60-
def _generate_examples(self, logs_files):
61-
for file_idx, log_files in enumerate(logs_files):
71+
def _generate_shards(self, base_files, logs_files_iterables):
72+
yield from base_files
73+
74+
def _generate_examples(self, base_files, logs_files_iterables):
75+
for file_idx, log_files in enumerate(logs_files_iterables):
6276
for sample_idx, sample in enumerate(self._iter_samples_from_log_files(log_files)):
6377
yield Key(file_idx, sample_idx), sample

src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
112112
labels.add(os.path.basename(os.path.dirname(downloaded_dir_file)))
113113
path_depths.add(count_path_segments(downloaded_dir_file))
114114
elif os.path.basename(downloaded_dir_file) in metadata_filenames:
115-
metadata_files[split].add((None, downloaded_dir_file))
115+
metadata_files[split].add((None, downloaded_dir, downloaded_dir_file))
116116
else:
117117
archive_file_name = os.path.basename(archive)
118118
original_file_name = os.path.basename(downloaded_dir_file)
@@ -123,8 +123,6 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
123123
data_files = self.config.data_files
124124
splits = []
125125
for split_name, files in data_files.items():
126-
if isinstance(files, str):
127-
files = [files]
128126
files, archives = self._split_files_and_archives(files)
129127
downloaded_files = dl_manager.download(files)
130128
downloaded_dirs = dl_manager.download_and_extract(archives)
@@ -156,12 +154,17 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
156154
else:
157155
add_labels, add_metadata, metadata_files = False, False, {}
158156

157+
# files info (original_file, downloaded_file)
158+
files = tuple(zip(files, downloaded_files))
159+
# dirs info (original_file, downloaded_dir, downloaded_files)
160+
files += tuple(
161+
(None, downloaded_dir, dl_manager.iter_files(downloaded_dir)) for downloaded_dir in downloaded_dirs
162+
)
159163
splits.append(
160164
datasets.SplitGenerator(
161165
name=split_name,
162166
gen_kwargs={
163-
"files": tuple(zip(files, downloaded_files))
164-
+ tuple((None, dl_manager.iter_files(downloaded_dir)) for downloaded_dir in downloaded_dirs),
167+
"files": files,
165168
"metadata_files": metadata_files.get(split_name, []),
166169
"add_labels": add_labels,
167170
"add_metadata": add_metadata,
@@ -267,7 +270,7 @@ def _split_files_and_archives(self, data_files):
267270
files.append(data_file)
268271
elif os.path.basename(data_file) in metadata_filenames:
269272
files.append(data_file)
270-
else:
273+
elif data_file_ext.lower() == ".zip":
271274
archives.append(data_file)
272275
return files, archives
273276

@@ -354,6 +357,14 @@ def _read_metadata(self, metadata_file: str, metadata_ext: str = "") -> Iterator
354357
):
355358
yield pa.Table.from_batches([record_batch])
356359

360+
def _generate_shards(self, files, metadata_files, add_metadata, add_labels):
361+
if add_metadata:
362+
for original_metadata_file, downloaded_metadata_file in metadata_files:
363+
yield downloaded_metadata_file
364+
else:
365+
for original_file, downloaded_file_or_dir in files:
366+
yield downloaded_file_or_dir
367+
357368
def _generate_examples(self, files, metadata_files, add_metadata, add_labels):
358369
if add_metadata:
359370
feature_paths = []
@@ -365,7 +376,11 @@ def find_feature_path(feature, feature_path):
365376

366377
_visit_with_path(self.info.features, find_feature_path)
367378

368-
for shard_idx, (original_metadata_file, downloaded_metadata_file) in enumerate(metadata_files):
379+
for shard_idx, metadata_file_info in enumerate(metadata_files):
380+
if len(metadata_file_info) == 2:
381+
original_metadata_file, downloaded_metadata_file = metadata_file_info
382+
else:
383+
original_metadata_file, downloaded_metadata_dir, downloaded_metadata_file = metadata_file_info
369384
metadata_ext = os.path.splitext(original_metadata_file or downloaded_metadata_file)[-1]
370385
downloaded_metadata_dir = os.path.dirname(downloaded_metadata_file)
371386

@@ -395,12 +410,13 @@ def set_feature(item, feature_path: _VisitPath):
395410
if isinstance(self.config.filters, list)
396411
else self.config.filters
397412
)
398-
for shard_idx, (original_file, downloaded_file_or_dir) in enumerate(files):
399-
downloaded_files = [downloaded_file_or_dir] if original_file else downloaded_file_or_dir
413+
for shard_idx, file_or_dir_info in enumerate(files):
414+
if len(file_or_dir_info) == 2:
415+
original_file, downloaded_file = file_or_dir_info
416+
downloaded_files = [downloaded_file]
417+
else:
418+
original_file, downloaded_dir, downloaded_files = file_or_dir_info
400419
for sample_idx, downloaded_file in enumerate(downloaded_files):
401-
original_file_ext = os.path.splitext(original_file or downloaded_file)[-1]
402-
if original_file_ext.lower() not in self.EXTENSIONS:
403-
continue
404420
sample = {self.BASE_COLUMN_NAME: downloaded_file}
405421
if add_labels:
406422
sample["label"] = os.path.basename(os.path.dirname(original_file or downloaded_file))

src/datasets/packaged_modules/hdf5/hdf5.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
from dataclasses import dataclass, field
32
from typing import TYPE_CHECKING, Optional
43

@@ -51,29 +50,27 @@ def _split_generators(self, dl_manager):
5150

5251
if not self.config.data_files:
5352
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
54-
dl_manager.download_config.extract_on_the_fly = True
55-
data_files = dl_manager.download_and_extract(self.config.data_files)
53+
data_files = dl_manager.download(self.config.data_files)
5654
splits = []
5755
for split_name, files in data_files.items():
58-
if isinstance(files, str):
59-
files = [files]
60-
61-
files = [dl_manager.iter_files(file) for file in files]
6256
# Infer features from first file
6357
if self.info.features is None:
64-
for first_file in itertools.chain.from_iterable(files):
58+
for first_file in files:
6559
with open(first_file, "rb") as f:
6660
with h5py.File(f, "r") as h5:
6761
self.info.features = _recursive_infer_features(h5)
6862
break
6963
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
7064
return splits
7165

66+
def _generate_shards(self, files):
67+
yield from files
68+
7269
def _generate_tables(self, files):
7370
import h5py
7471

7572
batch_size_cfg = self.config.batch_size
76-
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
73+
for file_idx, file in enumerate(files):
7774
try:
7875
with open(file, "rb") as f:
7976
with h5py.File(f, "r") as h5:

0 commit comments

Comments
 (0)