Commit dc98f97

Add _generate_shards (#7943)
* add _generate_shards
* fix tests
* again
* again
1 parent 38d28bf commit dc98f97

File tree

20 files changed: +371 additions, -262 deletions


src/datasets/builder.py

Lines changed: 40 additions & 0 deletions
@@ -1335,6 +1335,26 @@ class GeneratorBasedBuilder(DatasetBuilder):
     (`_split_generators`). See the method docstrings for details.
     """
 
+    def _generate_shards(self, **kwargs) -> Iterator[Union[str, dict[str, Any]]]:
+        """Default function generating shard paths for each `SplitGenerator`.
+
+        This function is useful to list the original shards that the data comes from,
+        which are either converted to Arrow or streamed to an IterableDataset.
+
+        It is optional and only used by certain utilities, not by Dataset or
+        IterableDataset. For example, it is used to map original shard files to
+        Parquet files in the Dataset Viewer after conversion.
+
+        Args:
+            **kwargs (additional keyword arguments):
+                Arguments forwarded from the SplitGenerator.gen_kwargs
+
+        Yields:
+            shard: generally a string representing the shard path, or a dict
+                representing the shard when shards span parts of files or multiple files.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def _generate_examples(self, **kwargs) -> Iterator[tuple[Key, dict[str, Any]]]:
         """Default function generating examples for each `SplitGenerator`.
@@ -1624,6 +1644,26 @@ def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> E
 class ArrowBasedBuilder(DatasetBuilder):
     """Base class for datasets with data generation based on Arrow loading functions (CSV/JSON/Parquet)."""
 
+    def _generate_shards(self, **kwargs) -> Iterator[Union[str, dict[str, Any]]]:
+        """Default function generating shard paths for each `SplitGenerator`.
+
+        This function is useful to list the original shards that the data comes from,
+        which are either converted to Arrow or streamed to an IterableDataset.
+
+        It is optional and only used by certain utilities, not by Dataset or
+        IterableDataset. For example, it is used to map original shard files to
+        Parquet files in the Dataset Viewer after conversion.
+
+        Args:
+            **kwargs (additional keyword arguments):
+                Arguments forwarded from the SplitGenerator.gen_kwargs
+
+        Yields:
+            shard: generally a string representing the shard path, or a dict
+                representing the shard when shards span parts of files or multiple files.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def _generate_tables(self, **kwargs) -> Iterator[tuple[Key, pa.Table]]:
         """Default function generating examples for each `SplitGenerator`.

src/datasets/packaged_modules/arrow/arrow.py

Lines changed: 6 additions & 9 deletions
@@ -1,4 +1,3 @@
-import itertools
 from dataclasses import dataclass
 from typing import Optional
 
@@ -32,17 +31,12 @@ def _split_generators(self, dl_manager):
         """We handle string, list and dicts in datafiles"""
         if not self.config.data_files:
             raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
-        dl_manager.download_config.extract_on_the_fly = True
-        data_files = dl_manager.download_and_extract(self.config.data_files)
+        data_files = dl_manager.download(self.config.data_files)
         splits = []
         for split_name, files in data_files.items():
-            if isinstance(files, str):
-                files = [files]
-            # Use `dl_manager.iter_files` to skip hidden files in an extracted archive
-            files = [dl_manager.iter_files(file) for file in files]
             # Infer features if they are stored in the arrow schema
             if self.info.features is None:
-                for file in itertools.chain.from_iterable(files):
+                for file in files:
                     with open(file, "rb") as f:
                         try:
                             reader = pa.ipc.open_stream(f)
@@ -60,8 +54,11 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.info.features.arrow_schema)
         return pa_table
 
+    def _generate_shards(self, files):
+        yield from files
+
     def _generate_tables(self, files):
-        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
+        for file_idx, file in enumerate(files):
             with open(file, "rb") as f:
                 try:
                     try:

src/datasets/packaged_modules/cache/cache.py

Lines changed: 3 additions & 0 deletions
@@ -176,6 +176,9 @@ def _split_generators(self, dl_manager):
             for split_info in split_infos
         ]
 
+    def _generate_shards(self, files):
+        yield from files
+
     def _generate_tables(self, files):
         # used to stream from cache
         for file_idx, file in enumerate(files):

src/datasets/packaged_modules/csv/csv.py

Lines changed: 27 additions & 20 deletions
@@ -1,4 +1,3 @@
-import itertools
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Union
 
@@ -154,13 +153,17 @@ def _split_generators(self, dl_manager):
         if not self.config.data_files:
             raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
         dl_manager.download_config.extract_on_the_fly = True
-        data_files = dl_manager.download_and_extract(self.config.data_files)
+        base_data_files = dl_manager.download(self.config.data_files)
+        extracted_data_files = dl_manager.extract(base_data_files)
         splits = []
-        for split_name, files in data_files.items():
-            if isinstance(files, str):
-                files = [files]
-            files = [dl_manager.iter_files(file) for file in files]
-            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+        for split_name, extracted_files in extracted_data_files.items():
+            files_iterables = [dl_manager.iter_files(extracted_file) for extracted_file in extracted_files]
+            splits.append(
+                datasets.SplitGenerator(
+                    name=split_name,
+                    gen_kwargs={"files_iterables": files_iterables, "base_files": base_data_files[split_name]},
+                )
+            )
         return splits
 
     def _cast_table(self, pa_table: pa.Table) -> pa.Table:
@@ -174,7 +177,10 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, schema)
         return pa_table
 
-    def _generate_tables(self, files):
+    def _generate_shards(self, base_files, files_iterables):
+        yield from base_files
+
+    def _generate_tables(self, base_files, files_iterables):
         schema = self.config.features.arrow_schema if self.config.features else None
         # dtype allows reading an int column as str
         dtype = (
@@ -185,15 +191,16 @@ def _generate_tables(self, files):
             if schema is not None
             else None
         )
-        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
-            try:
-                for batch_idx, df in enumerate(csv_file_reader):
-                    pa_table = pa.Table.from_pandas(df)
-                    # Uncomment for debugging (will print the Arrow table size and elements)
-                    # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
-                    # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
-                    yield Key(file_idx, batch_idx), self._cast_table(pa_table)
-            except ValueError as e:
-                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
-                raise
+        for shard_idx, files_iterable in enumerate(files_iterables):
+            for file in files_iterable:
+                csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
+                try:
+                    for batch_idx, df in enumerate(csv_file_reader):
+                        pa_table = pa.Table.from_pandas(df)
+                        # Uncomment for debugging (will print the Arrow table size and elements)
+                        # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
+                        # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
+                        yield Key(shard_idx, batch_idx), self._cast_table(pa_table)
+                except ValueError as e:
+                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                    raise
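A note on the new gen_kwargs in the CSV loader: each split now carries both the original downloaded files (base_files) and per-file iterables over the extracted contents (files_iterables). A rough illustration with made-up paths, assuming a single compressed CSV per split:

    # Hypothetical values, for illustration only: _generate_shards yields the original
    # (possibly still compressed) downloads, while _generate_tables reads the extracted files.
    gen_kwargs = {
        "base_files": ["/cache/downloads/train.csv.gz"],            # yielded as shards
        "files_iterables": [iter(["/cache/extracted/train.csv"])],  # rows are read from here
    }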

src/datasets/packaged_modules/eval/eval.py

Lines changed: 24 additions & 10 deletions
@@ -1,6 +1,7 @@
 import json
 import os
 from itertools import islice
+from typing import Iterable
 
 import pyarrow as pa
 
@@ -22,16 +23,26 @@ def _split_generators(self, dl_manager):
         if not self.config.data_files:
             raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
         dl_manager.download_config.extract_on_the_fly = True
-        data_files = dl_manager.download_and_extract(self.config.data_files)
+        base_data_files = dl_manager.download(self.config.data_files)
+        extracted_data_files = dl_manager.extract(base_data_files)
         splits = []
-        for split_name, logs in data_files.items():
-            if isinstance(logs, str):
-                logs = [logs]
-            logs_files = [dl_manager.iter_files(log) for log in logs]
-            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"logs_files": logs_files}))
+        for split_name, logs in extracted_data_files.items():
+            logs_files_iterables = [dl_manager.iter_files(log) for log in logs]
+            splits.append(
+                datasets.SplitGenerator(
+                    name=split_name,
+                    gen_kwargs={
+                        "logs_files_iterables": logs_files_iterables,
+                        "base_files": base_data_files[split_name],
+                    },
+                )
+            )
         if not self.info.features:
             first_examples = list(
-                islice(self._iter_samples_from_log_files(logs_files[0]), self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE)
+                islice(
+                    self._iter_samples_from_log_files(logs_files_iterables[0]),
+                    self.NUM_EXAMPLES_FOR_FEATURES_INFERENCE,
+                )
             )
             pa_tables = [pa.Table.from_pylist([example]) for example in first_examples]
             inferred_arrow_schema = pa.concat_tables(pa_tables, promote_options="default").schema
@@ -44,7 +55,7 @@ def _sort_samples_key(self, sample_path: str):
         (sample_idx_str, epoch_idx_str) = os.path.splitext(os.path.basename(sample_path))[0].split("_epoch_")
         return (int(epoch_idx_str), int(sample_idx_str))
 
-    def _iter_samples_from_log_files(self, log_files: list[str]):
+    def _iter_samples_from_log_files(self, log_files: Iterable[str]):
         sample_files = [log_file for log_file in log_files if os.path.basename(os.path.dirname(log_file)) == "samples"]
         sample_files.sort(key=self._sort_samples_key)
         for sample_file in sample_files:
@@ -57,7 +68,10 @@ def _iter_samples_from_log_files(self, log_files: list[str]):
                     sample[field] = [json.dumps(x) for x in sample[field]]
                 yield sample
 
-    def _generate_examples(self, logs_files):
-        for file_idx, log_files in enumerate(logs_files):
+    def _generate_shards(self, base_files, logs_files_iterables):
+        yield from base_files
+
+    def _generate_examples(self, base_files, logs_files_iterables):
+        for file_idx, log_files in enumerate(logs_files_iterables):
             for sample_idx, sample in enumerate(self._iter_samples_from_log_files(log_files)):
                 yield Key(file_idx, sample_idx), sample
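Finally, a hedged sketch of how a conversion utility might consume the new hook to record which original shards back each split, the use case mentioned in the builder.py docstring. The helper name and the direct calls to the private _split_generators/_generate_shards methods are assumptions for illustration, not an API introduced by this commit:

    def list_original_shards(builder, dl_manager):
        # Maps split name -> list of original shard paths (or dicts), or [] when the
        # builder does not implement _generate_shards (the default raises NotImplementedError).
        shards_per_split = {}
        for split_generator in builder._split_generators(dl_manager):
            try:
                shards_per_split[split_generator.name] = list(
                    builder._generate_shards(**split_generator.gen_kwargs)
                )
            except NotImplementedError:
                shards_per_split[split_generator.name] = []
        return shards_per_split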
