diff --git a/src/datasets/arrow_reader.py b/src/datasets/arrow_reader.py index 3bbb58a59c3..7011dbc9585 100644 --- a/src/datasets/arrow_reader.py +++ b/src/datasets/arrow_reader.py @@ -120,7 +120,7 @@ def make_file_instructions( dataset_name=name, split=info.name, filetype_suffix=filetype_suffix, - shard_lengths=name2shard_lengths[info.name], + num_shards=len(name2shard_lengths[info.name] or ()), ) for info in split_infos } diff --git a/src/datasets/builder.py b/src/datasets/builder.py index e63960dcabf..f4f13194284 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -18,6 +18,7 @@ import abc import contextlib import copy +import fnmatch import inspect import os import posixpath @@ -29,7 +30,7 @@ from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, Union from unittest.mock import patch import fsspec @@ -59,7 +60,12 @@ from .info import DatasetInfo, PostProcessedInfo from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset from .keyhash import DuplicatedKeysError -from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase +from .naming import ( + INVALID_WINDOWS_CHARACTERS_IN_PATH, + camelcase_to_snakecase, + filenames_for_dataset_split, + filepattern_for_dataset_split, +) from .splits import Split, SplitDict, SplitGenerator, SplitInfo from .streaming import extend_dataset_builder_for_streaming from .table import CastError @@ -69,6 +75,7 @@ from .utils.file_utils import is_remote_url from .utils.info_utils import VerificationMode, get_size_checksum_dict, verify_checksums, verify_splits from .utils.py_utils import ( + NestedDataStructure, classproperty, convert_file_size_to_int, has_sufficient_disk_space, @@ -77,6 +84,7 @@ memoize, size_str, temporary_assignment, + unique_values, ) from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs from .utils.track import tracked_list @@ -679,6 +687,10 @@ def _info(self) -> DatasetInfo: info: (DatasetInfo) The dataset information """ raise NotImplementedError + + def _supports_partial_generation(self) -> bool: + """Whether the dataset supports generation of specific splits.""" + return hasattr(self, "_available_splits") and "splits" in inspect.signature(self._split_generators).parameters @classmethod def get_imported_module_dir(cls): @@ -691,6 +703,7 @@ def _rename(self, src: str, dst: str): def download_and_prepare( self, output_dir: Optional[str] = None, + split: Optional[Union[str, ReadInstruction, Split]] = None, download_config: Optional[DownloadConfig] = None, download_mode: Optional[Union[DownloadMode, str]] = None, verification_mode: Optional[Union[VerificationMode, str]] = None, @@ -710,6 +723,8 @@ def download_and_prepare( Default to this builder's `cache_dir`, which is inside `~/.cache/huggingface/datasets` by default. + split (`Union[str, ReadInstruction, Split]`, *optional*): + Splits to generate. Default to all splits. download_config (`DownloadConfig`, *optional*): Specific download configuration parameters. 
download_mode ([`DownloadMode`] or `str`, *optional*): @@ -828,12 +843,60 @@ def download_and_prepare( # File locking only with local paths; no file locking on GCS or S3 with FileLock(lock_path) if is_local else contextlib.nullcontext(): # Check if the data already exists - data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) - if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") + info_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME)) + if info_exists: # We need to update the info in case some splits were added in the meantime # for example when calling load_dataset from multiple workers. self.info = self._load_info() + _dataset_name = self.name if self._check_legacy_cache() else self.dataset_name + splits: Optional[List[str]] = None + cached_split_filepatterns = [] + supports_partial_generation = self._supports_partial_generation() + if supports_partial_generation: + if split: + splits = [] + for split in NestedDataStructure(split).flatten(): + if not isinstance(split, ReadInstruction): + split = str(split) + if split == Split.ALL: + splits = None # generate all splits + break + split = ReadInstruction.from_spec(split) + split_names = [rel_instr.splitname for rel_instr in split._relative_instructions] + splits.extend(split_names) + splits = list(unique_values(splits)) # remove duplicates + available_splits = self._available_splits() + if splits is None: + splits = available_splits + missing_splits = set(splits) - set(available_splits) + if missing_splits: + raise ValueError(f"Splits {list(missing_splits)} not found. Available splits: {available_splits}") + if download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + for split_name in splits[:]: + num_shards = 1 + if self.info.splits is not None: + try: + num_shards = len(self.info.splits[split_name].shard_lengths or ()) + except Exception: + pass + split_filenames = filenames_for_dataset_split( + self._output_dir, + _dataset_name, + split_name, + filetype_suffix=file_format, + num_shards=num_shards, + ) + if self._fs.exists(split_filenames[0]): + splits.remove(split_name) + split_filepattern = filepattern_for_dataset_split( + self._output_dir, _dataset_name, split_name, filetype_suffix=file_format + ) + cached_split_filepatterns.append(split_filepattern) + # We cannot use info as the source of truth if the builder supports partial generation + # as the info can be incomplete in that case + requested_splits_exist = not splits if supports_partial_generation else info_exists + if requested_splits_exist and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})") self.download_post_processing_resources(dl_manager) return @@ -858,16 +921,33 @@ def incomplete_dir(dirname): try: yield tmp_dir if os.path.isdir(dirname): - shutil.rmtree(dirname) + for root, dirnames, filenames in os.walk(dirname, topdown=False): # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory - shutil.move(tmp_dir, dirname) + for filename in filenames: + filename = os.path.join(root, filename) + delete_filename = True + for cached_split_filepattern in cached_split_filepatterns: + if fnmatch.fnmatch(filename, cached_split_filepattern): + delete_filename = False + break + if delete_filename: + os.remove(filename) + for dirname in dirnames: + dirname = os.path.join(root, dirname) + if
len(os.listdir(dirname)) == 0: + os.rmdir(dirname) + for file_or_dir in os.listdir(tmp_dir): + try: + shutil.move(os.path.join(tmp_dir, file_or_dir), dirname) + except shutil.Error: + # If the file already exists in the distributed setup + pass + else: + shutil.move(tmp_dir, dirname) finally: if os.path.exists(tmp_dir): shutil.rmtree(tmp_dir) - # Print is intentional: we want this to always go to stdout so user has - # information needed to cancel download/preparation if needed. - # This comes right before the progress bar. if self.info.size_in_bytes: logger.info( f"Downloading and preparing dataset {self.dataset_name}/{self.config.name} " @@ -886,7 +966,7 @@ def incomplete_dir(dirname): # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. with temporary_assignment(self, "_output_dir", tmp_output_dir): - prepare_split_kwargs = {"file_format": file_format} + prepare_split_kwargs = {"file_format": file_format, "splits": splits} if max_shard_size is not None: prepare_split_kwargs["max_shard_size"] = max_shard_size if num_proc is not None: @@ -898,7 +978,15 @@ def incomplete_dir(dirname): **download_and_prepare_kwargs, ) # Sync info + if supports_partial_generation and self.info.download_checksums is not None: + self.info.download_checksums.update(dl_manager.get_recorded_sizes_checksums()) + else: + self.info.download_checksums = dl_manager.get_recorded_sizes_checksums() + self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) + self.info.download_size = sum( + checksum["num_bytes"] for checksum in self.info.download_checksums.values() + ) self.info.download_checksums = dl_manager.get_recorded_sizes_checksums() if self.info.download_size is not None: self.info.size_in_bytes = self.info.dataset_size + self.info.download_size @@ -942,7 +1030,8 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k if `NO_CHECKS`, do not perform any verification. prepare_split_kwargs: Additional options, such as `file_format`, `max_shard_size` """ - # Generating data for all splits + # If `splits` is specified and the builder supports `splits` in `_split_generators`, then only generate the specified splits. + # Otherwise, generate all splits split_dict = SplitDict(dataset_name=self.dataset_name) split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) split_generators = self._split_generators(dl_manager, **split_generators_kwargs) @@ -950,7 +1039,9 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k # Checksums verification if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums: verify_checksums( - self.info.download_checksums, dl_manager.get_recorded_sizes_checksums(), "dataset source files" + self.info.download_checksums, + dl_manager.get_recorded_sizes_checksums(), + "dataset source files", ) # Build splits @@ -987,9 +1078,18 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k if verification_mode == VerificationMode.BASIC_CHECKS or verification_mode == VerificationMode.ALL_CHECKS: verify_splits(self.info.splits, split_dict) - # Update the info object with the splits. - self.info.splits = split_dict - self.info.download_size = dl_manager.downloaded_size + # Update the info object with the generated splits. 
+ if self._supports_partial_generation(): + split_infos = self.info.splits or {} + ordered_split_infos = {} + for split_name in self._available_splits(): + if split_name in split_dict: + ordered_split_infos[split_name] = split_dict[split_name] + elif split_name in split_infos: + ordered_split_infos[split_name] = split_infos[split_name] + self.info.splits = SplitDict.from_split_dict(ordered_split_infos, dataset_name=self.dataset_name) + else: + self.info.splits = split_dict def download_post_processing_resources(self, dl_manager): for split in self.info.splits or []: @@ -1021,7 +1121,9 @@ def _save_info(self): def _make_split_generators_kwargs(self, prepare_split_kwargs): """Get kwargs for `self._split_generators()` from `prepare_split_kwargs`.""" - del prepare_split_kwargs + splits = prepare_split_kwargs.pop("splits", None) + if self._supports_partial_generation(): + return {"splits": splits} return {} def as_dataset( @@ -1075,11 +1177,12 @@ def as_dataset( "datasets.load_dataset() before trying to access the Dataset object." ) - logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}") + available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits + logger.debug(f'Constructing Dataset for split {split or ", ".join(available_splits)}, from {self._output_dir}') # By default, return all splits if split is None: - split = {s: s for s in self.info.splits} + split = {s: s for s in available_splits} verification_mode = VerificationMode(verification_mode or VerificationMode.BASIC_CHECKS) @@ -1107,10 +1210,11 @@ def _build_single_dataset( in_memory: bool = False, ): """as_dataset for a single split.""" + available_splits = self._available_splits() if self._supports_partial_generation() else self.info.splits if not isinstance(split, ReadInstruction): split = str(split) - if split == "all": - split = "+".join(self.info.splits.keys()) + if split == Split.ALL: + split = "+".join(available_splits) split = Split(split) # Build base dataset @@ -1222,8 +1326,12 @@ def as_streaming_dataset( data_dir=self.config.data_dir, ) self._check_manual_download(dl_manager) - splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager)} - # By default, return all splits + splits_generators_kwargs = {} + if self._supports_partial_generation(): + splits_generators_kwargs["splits"] = [split] if split else None + splits_generators = {sg.name: sg for sg in self._split_generators(dl_manager, **splits_generators_kwargs)} + # We still need this in case the builder's `_splits_generators` does not support the `splits` argument + # to filter the splits if split is None: splits_generator = splits_generators elif split in splits_generators: @@ -1403,9 +1511,9 @@ def _prepare_split( ): max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) - if self.info.splits is not None: + try: split_info = self.info.splits[split_generator.name] - else: + except Exception: split_info = split_generator.split_info SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" diff --git a/src/datasets/info.py b/src/datasets/info.py index 3723439fb91..1a1e86b5e10 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -176,7 +176,7 @@ def __post_init__(self): else: self.version = Version.from_dict(self.version) if self.splits is not None and not isinstance(self.splits, SplitDict): - self.splits = SplitDict.from_split_dict(self.splits) + self.splits = SplitDict.from_split_dict(self.splits, self.dataset_name) if 
self.supervised_keys is not None and not isinstance(self.supervised_keys, SupervisedKeysData): if isinstance(self.supervised_keys, (tuple, list)): self.supervised_keys = SupervisedKeysData(*self.supervised_keys) diff --git a/src/datasets/load.py b/src/datasets/load.py index bc2b0e679b6..17e2bd8d13d 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1410,6 +1410,7 @@ def load_dataset( # Download and prepare data builder_instance.download_and_prepare( + split = split, download_config=download_config, download_mode=download_mode, verification_mode=verification_mode, diff --git a/src/datasets/naming.py b/src/datasets/naming.py index 65e7ede10dc..70eeb63e423 100644 --- a/src/datasets/naming.py +++ b/src/datasets/naming.py @@ -17,6 +17,7 @@ import itertools import os +import posixpath import re @@ -46,33 +47,33 @@ def snakecase_to_camelcase(name): def filename_prefix_for_name(name): - if os.path.basename(name) != name: + if posixpath.basename(name) != name: raise ValueError(f"Should be a dataset name, not a path: {name}") return camelcase_to_snakecase(name) def filename_prefix_for_split(name, split): - if os.path.basename(name) != name: + if posixpath.basename(name) != name: raise ValueError(f"Should be a dataset name, not a path: {name}") if not re.match(_split_re, split): raise ValueError(f"Split name should match '{_split_re}'' but got '{split}'.") return f"{filename_prefix_for_name(name)}-{split}" -def filepattern_for_dataset_split(dataset_name, split, data_dir, filetype_suffix=None): +def filepattern_for_dataset_split(path, dataset_name, split, filetype_suffix=None): prefix = filename_prefix_for_split(dataset_name, split) + filepath = posixpath.join(path, prefix) + filepath = f"{filepath}*" if filetype_suffix: - prefix += f".{filetype_suffix}" - filepath = os.path.join(data_dir, prefix) - return f"{filepath}*" + filepath += f".{filetype_suffix}" + return filepath -def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None): +def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, num_shards=1): prefix = filename_prefix_for_split(dataset_name, split) - prefix = os.path.join(path, prefix) + prefix = posixpath.join(path, prefix) - if shard_lengths: - num_shards = len(shard_lengths) + if num_shards > 1: filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)] if filetype_suffix: filenames = [filename + f".{filetype_suffix}" for filename in filenames] diff --git a/src/datasets/packaged_modules/arrow/arrow.py b/src/datasets/packaged_modules/arrow/arrow.py index bcf31c473d2..501179d6d54 100644 --- a/src/datasets/packaged_modules/arrow/arrow.py +++ b/src/datasets/packaged_modules/arrow/arrow.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pyarrow as pa @@ -27,12 +27,18 @@ class Arrow(datasets.ArrowBasedBuilder): def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") 
dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index cdcfb4c20b6..a60b59af64d 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -153,10 +153,15 @@ def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs if output_dir is not None and output_dir != self.cache_dir: shutil.copytree(self.cache_dir, output_dir) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.info.splits] + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): # used to stream from cache if isinstance(self.info.splits, datasets.SplitDict): split_infos: list[datasets.SplitInfo] = list(self.info.splits.values()) + if splits: + split_infos = [split_info for split_info in split_infos if split_info.name in splits] else: raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}") return [ @@ -168,7 +173,7 @@ def _split_generators(self, dl_manager): dataset_name=self.dataset_name, split=split_info.name, filetype_suffix="arrow", - shard_lengths=split_info.shard_lengths, + num_shards=len(split_info.shard_lengths or ()), ) }, ) diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index 2ae95ff5142..18edb6e23be 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -148,12 +148,18 @@ class Csv(datasets.ArrowBasedBuilder): def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py index 182de467b14..07b4822753e 100644 --- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py +++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py @@ -67,7 +67,10 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in 
self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True @@ -120,6 +123,8 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split): ) data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index c5d8bcd03fc..75cf6fe0fea 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -1,7 +1,7 @@ import io import itertools from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pandas as pd import pyarrow as pa @@ -70,12 +70,18 @@ def _info(self): raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported") return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 10797753657..f9129a57065 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -41,12 +41,18 @@ def _info(self): ) return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/text/text.py b/src/datasets/packaged_modules/text/text.py index a1f3ff5a744..f72a50139d0 100644 
--- a/src/datasets/packaged_modules/text/text.py +++ b/src/datasets/packaged_modules/text/text.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass from io import StringIO -from typing import Optional +from typing import List, Optional import pyarrow as pa @@ -31,7 +31,10 @@ class Text(datasets.ArrowBasedBuilder): def _info(self): return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. If str or List[str], then the dataset returns only the 'train' split. @@ -40,7 +43,10 @@ def _split_generators(self, dl_manager): if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") dl_manager.download_config.extract_on_the_fly = True - data_files = dl_manager.download_and_extract(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download_and_extract(data_files) splits = [] for split_name, files in data_files.items(): if isinstance(files, str): diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py index 571276a4cd5..fbd5fb95025 100644 --- a/src/datasets/packaged_modules/webdataset/webdataset.py +++ b/src/datasets/packaged_modules/webdataset/webdataset.py @@ -2,7 +2,7 @@ import json import re from itertools import islice -from typing import Any, Callable +from typing import Any, Callable, Dict, List, Optional import fsspec import numpy as np @@ -59,12 +59,18 @@ def _get_pipeline_from_tar(cls, tar_path, tar_iterator): def _info(self) -> datasets.DatasetInfo: return datasets.DatasetInfo() - def _split_generators(self, dl_manager): + def _available_splits(self) -> Optional[List[str]]: + return [str(split) for split in self.config.data_files] if isinstance(self.config.data_files, dict) else None + + def _split_generators(self, dl_manager, splits: Optional[List[str]] = None): """We handle string, list and dicts in datafiles""" # Download the data files if not self.config.data_files: raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") - data_files = dl_manager.download(self.config.data_files) + data_files = self.config.data_files + if splits and isinstance(data_files, dict): + data_files = {split: data_files[split] for split in splits} + data_files = dl_manager.download(data_files) splits = [] for split_name, tar_paths in data_files.items(): if isinstance(tar_paths, str): diff --git a/src/datasets/utils/info_utils.py b/src/datasets/utils/info_utils.py index d93f5b4509f..550d5219150 100644 --- a/src/datasets/utils/info_utils.py +++ b/src/datasets/utils/info_utils.py @@ -45,11 +45,11 @@ def verify_checksums(expected_checksums: Optional[dict], recorded_checksums: dic if expected_checksums is None: logger.info("Unable to verify checksums.") return - if len(set(expected_checksums) - set(recorded_checksums)) > 0: - raise ExpectedMoreDownloadedFilesError(str(set(expected_checksums) - set(recorded_checksums))) - if len(set(recorded_checksums) - 
set(expected_checksums)) > 0: - raise UnexpectedDownloadedFileError(str(set(recorded_checksums) - set(expected_checksums))) - bad_urls = [url for url in expected_checksums if expected_checksums[url] != recorded_checksums[url]] + bad_urls = [ + url + for url in (set(recorded_checksums) & set(expected_checksums)) + if expected_checksums[url] != recorded_checksums[url] + ] for_verification_name = " for " + verification_name if verification_name is not None else "" if len(bad_urls) > 0: raise NonMatchingChecksumError( @@ -64,13 +64,9 @@ def verify_splits(expected_splits: Optional[dict], recorded_splits: dict): if expected_splits is None: logger.info("Unable to verify splits sizes.") return - if len(set(expected_splits) - set(recorded_splits)) > 0: - raise ExpectedMoreSplitsError(str(set(expected_splits) - set(recorded_splits))) - if len(set(recorded_splits) - set(expected_splits)) > 0: - raise UnexpectedSplitsError(str(set(recorded_splits) - set(expected_splits))) bad_splits = [ {"expected": expected_splits[name], "recorded": recorded_splits[name]} - for name in expected_splits + for name in (set(recorded_splits) & set(expected_splits)) if expected_splits[name].num_examples != recorded_splits[name].num_examples ] if len(bad_splits) > 0: diff --git a/tests/test_arrow_reader.py b/tests/test_arrow_reader.py index 6987416f3a4..6320d401a27 100644 --- a/tests/test_arrow_reader.py +++ b/tests/test_arrow_reader.py @@ -1,4 +1,5 @@ import os +import posixpath import tempfile from pathlib import Path from unittest import TestCase @@ -103,8 +104,8 @@ def test_read_files(self): reader = ReaderTest(tmp_dir, info) files = [ - {"filename": os.path.join(tmp_dir, "train")}, - {"filename": os.path.join(tmp_dir, "test"), "skip": 10, "take": 10}, + {"filename": posixpath.join(tmp_dir, "train")}, + {"filename": posixpath.join(tmp_dir, "test"), "skip": 10, "take": 10}, ] dset = Dataset(**reader.read_files(files, original_instructions="train+test[10:20]")) self.assertEqual(dset.num_rows, 110) @@ -169,7 +170,7 @@ def test_make_file_instructions_basic(): assert isinstance(file_instructions, FileInstructions) assert file_instructions.num_examples == 33 assert file_instructions.file_instructions == [ - {"filename": os.path.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33} + {"filename": posixpath.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33} ] split_infos = [SplitInfo(name="train", num_examples=100, shard_lengths=[10] * 10)] @@ -177,10 +178,10 @@ def test_make_file_instructions_basic(): assert isinstance(file_instructions, FileInstructions) assert file_instructions.num_examples == 33 assert file_instructions.file_instructions == [ - {"filename": os.path.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1}, - {"filename": os.path.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1}, - {"filename": os.path.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1}, - {"filename": os.path.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": posixpath.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3}, ] @@ -217,7 +218,7 @@ def 
test_make_file_instructions(split_name, instruction, shard_lengths, read_ran if not isinstance(shard_lengths, list): assert file_instructions.file_instructions == [ { - "filename": os.path.join(prefix_path, f"{name}-{split_name}.arrow"), + "filename": posixpath.join(prefix_path, f"{name}-{split_name}.arrow"), "skip": read_range[0], "take": read_range[1] - read_range[0], } @@ -226,7 +227,9 @@ def test_make_file_instructions(split_name, instruction, shard_lengths, read_ran file_instructions_list = [] shard_offset = 0 for i, shard_length in enumerate(shard_lengths): - filename = os.path.join(prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow") + filename = posixpath.join( + prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow" + ) if shard_offset <= read_range[0] < shard_offset + shard_length: file_instructions_list.append( { diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py index 08eb77366c1..2b09741e3f2 100644 --- a/tests/test_download_manager.py +++ b/tests/test_download_manager.py @@ -131,7 +131,6 @@ def test_download_manager_delete_extracted_files(xz_file): assert extracted_path == dl_manager.extracted_paths[xz_file] extracted_path = Path(extracted_path) parts = extracted_path.parts - # import pdb; pdb.set_trace() assert parts[-1] == hash_url_to_filename(str(xz_file), etag=None) assert parts[-2] == extracted_subdir assert extracted_path.exists() diff --git a/tests/test_load.py b/tests/test_load.py index a532452eb4c..1fbc3c003a5 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1048,6 +1048,70 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_exte assert ds.num_rows == 4 +def test_load_dataset_specific_splits(data_dir): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + with load_dataset(data_dir, split="test", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files) + + with pytest.raises(ValueError): + load_dataset(data_dir, split="non-existing-split", cache_dir=tmp_dir) + + +def test_load_dataset_specific_splits_then_full(data_dir): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(data_dir, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(data_dir, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + with load_dataset(data_dir, cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, DatasetDict) + assert len(dataset) > 0 + assert "train" in dataset + assert "test" in dataset + dataset_splits = list(dataset) + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(tuple(dataset_splits)) for arrow_file in arrow_files) + + +@pytest.mark.integration +def 
test_loading_from_dataset_from_hub_specific_splits(): + with tempfile.TemporaryDirectory() as tmp_dir: + with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="train", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + processed_dataset_dir = load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2, cache_dir=tmp_dir).cache_dir + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith("train") for arrow_file in arrow_files) + + with load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="test", cache_dir=tmp_dir) as dataset: + assert isinstance(dataset, Dataset) + assert len(dataset) > 0 + + arrow_files = Path(processed_dataset_dir).glob("*.arrow") + assert all(arrow_file.name.split("-", 1)[1].startswith(("train", "test")) for arrow_file in arrow_files) + + with pytest.raises(ValueError): + load_dataset(SAMPLE_DATASET_IDENTIFIER2, split="non-existing-split", cache_dir=tmp_dir) + + @pytest.mark.integration def test_loading_from_the_datasets_hub_with_token(): true_request = requests.Session().request
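Usage note (a minimal sketch based on the tests above, not part of the patch; "path/to/data_dir" is a placeholder for a local directory with train/test data files, and partial generation only applies to builders whose `_split_generators` accepts the `splits` argument):

    from datasets import load_dataset

    # Downloads and prepares only the "train" split: only train-*.arrow shards are written to the cache.
    train_ds = load_dataset("path/to/data_dir", split="train")

    # A later call prepares only the missing "test" split and reuses the cached "train" shards.
    test_ds = load_dataset("path/to/data_dir", split="test")

    # Requesting a split that does not exist raises a ValueError listing the available splits.
    load_dataset("path/to/data_dir", split="non-existing-split")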