Skip to content

Commit 6c66938

Browse files
committed
Reimplemented partial split download support (revival of huggingface#6832)
1 parent 6790e13 commit 6c66938

File tree

17 files changed

+294
-76
lines changed

17 files changed

+294
-76
lines changed

src/datasets/arrow_reader.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def make_file_instructions(
120120
dataset_name=name,
121121
split=info.name,
122122
filetype_suffix=filetype_suffix,
123-
shard_lengths=name2shard_lengths[info.name],
123+
num_shards=len(name2shard_lengths[info.name] or ()),
124124
)
125125
for info in split_infos
126126
}

src/datasets/builder.py

Lines changed: 133 additions & 25 deletions
Large diffs are not rendered by default.

src/datasets/info.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ def __post_init__(self):
176176
else:
177177
self.version = Version.from_dict(self.version)
178178
if self.splits is not None and not isinstance(self.splits, SplitDict):
179-
self.splits = SplitDict.from_split_dict(self.splits)
179+
self.splits = SplitDict.from_split_dict(self.splits, self.dataset_name)
180180
if self.supervised_keys is not None and not isinstance(self.supervised_keys, SupervisedKeysData):
181181
if isinstance(self.supervised_keys, (tuple, list)):
182182
self.supervised_keys = SupervisedKeysData(*self.supervised_keys)

src/datasets/load.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1410,6 +1410,7 @@ def load_dataset(
14101410

14111411
# Download and prepare data
14121412
builder_instance.download_and_prepare(
1413+
split = split,
14131414
download_config=download_config,
14141415
download_mode=download_mode,
14151416
verification_mode=verification_mode,

src/datasets/naming.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
import itertools
1919
import os
20+
import posixpath
2021
import re
2122

2223

@@ -46,33 +47,33 @@ def snakecase_to_camelcase(name):
4647

4748

4849
def filename_prefix_for_name(name):
49-
if os.path.basename(name) != name:
50+
if posixpath.basename(name) != name:
5051
raise ValueError(f"Should be a dataset name, not a path: {name}")
5152
return camelcase_to_snakecase(name)
5253

5354

5455
def filename_prefix_for_split(name, split):
55-
if os.path.basename(name) != name:
56+
if posixpath.basename(name) != name:
5657
raise ValueError(f"Should be a dataset name, not a path: {name}")
5758
if not re.match(_split_re, split):
5859
raise ValueError(f"Split name should match '{_split_re}'' but got '{split}'.")
5960
return f"{filename_prefix_for_name(name)}-{split}"
6061

6162

62-
def filepattern_for_dataset_split(dataset_name, split, data_dir, filetype_suffix=None):
63+
def filepattern_for_dataset_split(path, dataset_name, split, filetype_suffix=None):
6364
prefix = filename_prefix_for_split(dataset_name, split)
65+
filepath = posixpath.join(path, prefix)
66+
filepath = f"{filepath}*"
6467
if filetype_suffix:
65-
prefix += f".{filetype_suffix}"
66-
filepath = os.path.join(data_dir, prefix)
67-
return f"{filepath}*"
68+
filepath += f".{filetype_suffix}"
69+
return filepath
6870

6971

70-
def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None):
72+
def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, num_shards=1):
7173
prefix = filename_prefix_for_split(dataset_name, split)
72-
prefix = os.path.join(path, prefix)
74+
prefix = posixpath.join(path, prefix)
7375

74-
if shard_lengths:
75-
num_shards = len(shard_lengths)
76+
if num_shards > 1:
7677
filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)]
7778
if filetype_suffix:
7879
filenames = [filename + f".{filetype_suffix}" for filename in filenames]

src/datasets/packaged_modules/arrow/arrow.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import itertools
22
from dataclasses import dataclass
3-
from typing import Optional
3+
from typing import List, Optional
44

55
import pyarrow as pa
66

@@ -27,12 +27,18 @@ class Arrow(datasets.ArrowBasedBuilder):
2727
def _info(self):
2828
return datasets.DatasetInfo(features=self.config.features)
2929

30-
def _split_generators(self, dl_manager):
30+
def _available_splits(self) -> Optional[List[str]]:
31+
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None
32+
33+
def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
3134
"""We handle string, list and dicts in datafiles"""
3235
if not self.config.data_files:
3336
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
3437
dl_manager.download_config.extract_on_the_fly = True
35-
data_files = dl_manager.download_and_extract(self.config.data_files)
38+
data_files = self.config.data_files
39+
if splits and isinstance(data_files, dict):
40+
data_files = {split: data_files[split] for split in splits}
41+
data_files = dl_manager.download_and_extract(data_files)
3642
splits = []
3743
for split_name, files in data_files.items():
3844
if isinstance(files, str):

src/datasets/packaged_modules/cache/cache.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -153,10 +153,15 @@ def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs
153153
if output_dir is not None and output_dir != self.cache_dir:
154154
shutil.copytree(self.cache_dir, output_dir)
155155

156-
def _split_generators(self, dl_manager):
156+
def _available_splits(self) -> Optional[List[str]]:
157+
return [str(split) for split in self.info.splits]
158+
159+
def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
157160
# used to stream from cache
158161
if isinstance(self.info.splits, datasets.SplitDict):
159162
split_infos: list[datasets.SplitInfo] = list(self.info.splits.values())
163+
if splits:
164+
split_infos = [split_info for split_info in split_infos if split_info.name in splits]
160165
else:
161166
raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
162167
return [
@@ -168,7 +173,7 @@ def _split_generators(self, dl_manager):
168173
dataset_name=self.dataset_name,
169174
split=split_info.name,
170175
filetype_suffix="arrow",
171-
shard_lengths=split_info.shard_lengths,
176+
num_shards=len(split_info.shard_lengths or ()),
172177
)
173178
},
174179
)

src/datasets/packaged_modules/csv/csv.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,12 +148,18 @@ class Csv(datasets.ArrowBasedBuilder):
148148
def _info(self):
149149
return datasets.DatasetInfo(features=self.config.features)
150150

151-
def _split_generators(self, dl_manager):
151+
def _available_splits(self) -> Optional[List[str]]:
152+
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None
153+
154+
def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
152155
"""We handle string, list and dicts in datafiles"""
153156
if not self.config.data_files:
154157
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
155158
dl_manager.download_config.extract_on_the_fly = True
156-
data_files = dl_manager.download_and_extract(self.config.data_files)
159+
data_files = self.config.data_files
160+
if splits and isinstance(data_files, dict):
161+
data_files = {split: data_files[split] for split in splits}
162+
data_files = dl_manager.download_and_extract(data_files)
157163
splits = []
158164
for split_name, files in data_files.items():
159165
if isinstance(files, str):

src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,10 @@ def _info(self):
6767

6868
return datasets.DatasetInfo(features=self.config.features)
6969

70-
def _split_generators(self, dl_manager):
70+
def _available_splits(self) -> Optional[List[str]]:
71+
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None
72+
73+
def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
7174
if not self.config.data_files:
7275
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
7376
dl_manager.download_config.extract_on_the_fly = True
@@ -120,6 +123,8 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
120123
)
121124

122125
data_files = self.config.data_files
126+
if splits and isinstance(data_files, dict):
127+
data_files = {split: data_files[split] for split in splits}
123128
splits = []
124129
for split_name, files in data_files.items():
125130
if isinstance(files, str):

src/datasets/packaged_modules/json/json.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import io
22
import itertools
33
from dataclasses import dataclass
4-
from typing import Optional
4+
from typing import List, Optional
55

66
import pandas as pd
77
import pyarrow as pa
@@ -70,12 +70,18 @@ def _info(self):
7070
raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
7171
return datasets.DatasetInfo(features=self.config.features)
7272

73-
def _split_generators(self, dl_manager):
73+
def _available_splits(self) -> Optional[List[str]]:
74+
return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None
75+
76+
def _split_generators(self, dl_manager, splits: Optional[List[str]] = None):
7477
"""We handle string, list and dicts in datafiles"""
7578
if not self.config.data_files:
7679
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
7780
dl_manager.download_config.extract_on_the_fly = True
78-
data_files = dl_manager.download_and_extract(self.config.data_files)
81+
data_files = self.config.data_files
82+
if splits and isinstance(data_files, dict):
83+
data_files = {split: data_files[split] for split in splits}
84+
data_files = dl_manager.download_and_extract(data_files)
7985
splits = []
8086
for split_name, files in data_files.items():
8187
if isinstance(files, str):

0 commit comments

Comments (0)