diff --git a/src/datasets/load.py b/src/datasets/load.py index bc2b0e679b6..a057c246ea1 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -55,7 +55,7 @@ ) from .dataset_dict import DatasetDict, IterableDatasetDict from .download.download_config import DownloadConfig -from .download.download_manager import DownloadMode +from .download.download_manager import DownloadManager, DownloadMode from .download.streaming_download_manager import StreamingDownloadManager, xbasename, xglob, xjoin from .exceptions import DataFilesNotFoundError, DatasetNotFoundError from .features import Features @@ -1193,6 +1193,7 @@ def load_dataset( download_config: Optional[DownloadConfig] = None, download_mode: Optional[Union[DownloadMode, str]] = None, verification_mode: Optional[Union[VerificationMode, str]] = None, + dl_manager: Optional[DownloadManager] = None, keep_in_memory: Optional[bool] = None, save_infos: bool = False, revision: Optional[Union[str, Version]] = None, @@ -1275,6 +1276,8 @@ def load_dataset( Verification mode determining the checks to run on the downloaded/processed dataset information (checksums/size/splits/...). + dl_manager (`DownloadManager`, *optional*): + Specific `DownloadManger` to use. keep_in_memory (`bool`, defaults to `None`): Whether to copy the dataset in-memory. If `None`, the dataset will not be copied in-memory unless explicitly enabled by setting `datasets.config.IN_MEMORY_MAX_SIZE` to @@ -1408,11 +1411,33 @@ def load_dataset( if streaming: return builder_instance.as_streaming_dataset(split=split) + if dl_manager is None: + if download_config is None: + download_config = DownloadConfig( + cache_dir=builder_instance._cache_downloaded_dir, + force_download=download_mode == DownloadMode.FORCE_REDOWNLOAD, + force_extract=download_mode == DownloadMode.FORCE_REDOWNLOAD, + use_etag=False, + num_proc=num_proc, + token=builder_instance.token, + storage_options=builder_instance.storage_options, + ) # We don't use etag for data files to speed up the process + + dl_manager = DownloadManager( + dataset_name=builder_instance.dataset_name, + download_config=download_config, + data_dir=builder_instance.config.data_dir, + record_checksums=( + builder_instance._record_infos or verification_mode == VerificationMode.ALL_CHECKS + ), + ) + # Download and prepare data builder_instance.download_and_prepare( download_config=download_config, download_mode=download_mode, verification_mode=verification_mode, + dl_manager=dl_manager, num_proc=num_proc, storage_options=storage_options, )