diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 26965fb6a..5fcacf784 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -95,10 +95,10 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str: return os.path.join(cache_dir, name.lstrip("/")) -def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2) -> Any: +def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2, storage_options: Dict[str, Any] = {}) -> Any: """Wait until the file exists.""" file_exists = False - fs_provider = _get_fs_provider(remote_filepath) + fs_provider = _get_fs_provider(remote_filepath, storage_options) while not file_exists: file_exists = fs_provider.exists(remote_filepath) if not file_exists: @@ -121,7 +121,9 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: # 1. `queue_in`: A queue that receives the (index, paths) from where the data is to be downloaded. # 2. `queue_out`: A queue that sends the index after the files have been downloaded and ready to be used. # -def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None: +def _download_data_target( + input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Dict[str, Any] = {} +) -> None: """Download data from a remote directory to a cache directory to optimise reading.""" fs_provider = None @@ -163,7 +165,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue dirpath = os.path.dirname(local_path) os.makedirs(dirpath, exist_ok=True) if fs_provider is None: - fs_provider = _get_fs_provider(input_dir.url) + fs_provider = _get_fs_provider(input_dir.url, storage_options) fs_provider.download_file(path, local_path) elif os.path.isfile(path): @@ -221,12 +223,14 @@ def keep_path(path: str) -> bool: # 2. `remove_queue`: After uploading, the file is sent to the remove queue, # so it can be deleted from the cache directory. # -def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None: +def _upload_fn( + upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir, storage_options: Dict[str, Any] = {} +) -> None: """Upload optimised chunks from a local to remote dataset directory.""" obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path) if obj.scheme in _SUPPORTED_PROVIDERS: - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, storage_options) while True: data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get() @@ -465,6 +469,7 @@ def __init__( checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None, checkpoint_next_index: Optional[int] = None, item_loader: Optional[BaseItemLoader] = None, + storage_options: Dict[str, Any] = {}, ) -> None: """The BaseWorker is responsible to process the user data.""" self.worker_index = worker_index @@ -499,6 +504,7 @@ def __init__( self.use_checkpoint: bool = use_checkpoint self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info self.checkpoint_next_index: Optional[int] = checkpoint_next_index + self.storage_options = storage_options def run(self) -> None: try: @@ -695,6 +701,7 @@ def _start_downloaders(self) -> None: self.cache_data_dir, to_download_queue, self.ready_to_process_queue, + self.storage_options, ), ) p.start() @@ -734,6 +741,7 @@ def _start_uploaders(self) -> None: self.remove_queue, self.cache_chunks_dir, self.output_dir, + self.storage_options, ), ) p.start() @@ -848,8 +856,9 @@ def prepare_item(self, *args: Any, **kwargs: Any) -> Any: """ pass - def __init__(self) -> None: + def __init__(self, storage_options: Dict[str, Any] = {}) -> None: self._name: Optional[str] = None + self.storage_options = storage_options def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result: return _Result(size=size) @@ -862,8 +871,9 @@ def __init__( chunk_bytes: Optional[Union[int, str]] = None, compression: Optional[str] = None, encryption: Optional[Encryption] = None, + storage_options: Dict[str, Any] = {}, ): - super().__init__() + super().__init__(storage_options) if chunk_size is not None and chunk_bytes is not None: raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.") @@ -938,7 +948,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra local_filepath = os.path.join(cache_dir, _INDEX_FILENAME) if obj.scheme in _SUPPORTED_PROVIDERS: - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, self.storage_options) fs_provider.upload_file( local_filepath, os.path.join(output_dir.url, os.path.basename(local_filepath)), @@ -960,7 +970,8 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) if obj.scheme in _SUPPORTED_PROVIDERS: - fs_provider = _get_fs_provider(remote_filepath) + _wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options) + fs_provider = _get_fs_provider(remote_filepath, self.storage_options) fs_provider.download_file(remote_filepath, node_index_filepath) elif output_dir.path and os.path.isdir(output_dir.path): shutil.copyfile(remote_filepath, node_index_filepath) @@ -1002,6 +1013,7 @@ def __init__( use_checkpoint: bool = False, item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, + storage_options: Dict[str, Any] = {}, ): """Provides an efficient way to process data across multiple machine into chunks to make training faster. @@ -1026,6 +1038,7 @@ def __init__( the format in which the data is stored and optimized for loading. start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. + storage_options: Storage options for the cloud provider. """ # spawn doesn't work in IPython @@ -1060,6 +1073,7 @@ def __init__( self.checkpoint_chunks_info: Optional[List[List[Dict[str, Any]]]] = None self.checkpoint_next_index: Optional[List[int]] = None self.item_loader = item_loader + self.storage_options = storage_options self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)} @@ -1284,6 +1298,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None, self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None, self.item_loader, + self.storage_options, ) worker.start() workers.append(worker) @@ -1340,7 +1355,7 @@ def _cleanup_checkpoints(self) -> None: prefix = self.output_dir.url.rstrip("/") + "/" checkpoint_prefix = os.path.join(prefix, ".checkpoints") - fs_provider = _get_fs_provider(self.output_dir.url) + fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options) fs_provider.delete_file_or_directory(checkpoint_prefix) def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: @@ -1370,7 +1385,7 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None: if obj.scheme not in _SUPPORTED_PROVIDERS: not_supported_provider(self.output_dir.url) - fs_provider = _get_fs_provider(self.output_dir.url) + fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options) prefix = self.output_dir.url.rstrip("/") + "/" + ".checkpoints/" @@ -1441,7 +1456,7 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None: # download all the checkpoint files in tempdir and read them with tempfile.TemporaryDirectory() as temp_dir: - fs_provider = _get_fs_provider(self.output_dir.url) + fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options) saved_file_dir = fs_provider.download_directory(prefix, temp_dir) if not os.path.exists(os.path.join(saved_file_dir, "config.json")): diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index f2f9dd1e9..02c478b90 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -106,8 +106,13 @@ def _get_default_num_workers() -> int: class LambdaMapRecipe(MapRecipe): """Recipe for `map`.""" - def __init__(self, fn: Callable[[str, Any], None], inputs: Union[Sequence[Any], StreamingDataLoader]): - super().__init__() + def __init__( + self, + fn: Callable[[str, Any], None], + inputs: Union[Sequence[Any], StreamingDataLoader], + storage_options: Dict[str, Any] = {}, + ): + super().__init__(storage_options) self._fn = fn self._inputs = inputs self._device: Optional[str] = None @@ -160,8 +165,15 @@ def __init__( compression: Optional[str], encryption: Optional[Encryption] = None, existing_index: Optional[Dict[str, Any]] = None, + storage_options: Dict[str, Any] = {}, ): - super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption) + super().__init__( + chunk_size=chunk_size, + chunk_bytes=chunk_bytes, + compression=compression, + encryption=encryption, + storage_options=storage_options, + ) self._fn = fn self._inputs = inputs self.is_generator = False @@ -210,6 +222,7 @@ def map( batch_size: Optional[int] = None, start_method: Optional[str] = None, optimize_dns: Optional[bool] = None, + storage_options: Dict[str, Any] = {}, ) -> None: """Maps a callable over a collection of inputs, possibly in a distributed way. @@ -234,6 +247,7 @@ def map( start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. optimize_dns: Whether the optimized dns should be used. + storage_options: Storage options for the cloud provider. """ _check_version_and_prompt_upgrade(__version__) @@ -273,7 +287,7 @@ def map( ) if error_when_not_empty: - _assert_dir_is_empty(_output_dir) + _assert_dir_is_empty(_output_dir, storage_options=storage_options) if not isinstance(inputs, StreamingDataLoader): input_dir = input_dir or _get_input_dir(inputs) @@ -298,10 +312,11 @@ def map( weights=weights, reader=reader, start_method=start_method, + storage_options=storage_options, ) with optimize_dns_context(optimize_dns if optimize_dns is not None else False): - return data_processor.run(LambdaMapRecipe(fn, inputs)) + return data_processor.run(LambdaMapRecipe(fn, inputs, storage_options=storage_options)) return _execute( f"litdata-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}", num_nodes, @@ -351,6 +366,7 @@ def optimize( item_loader: Optional[BaseItemLoader] = None, start_method: Optional[str] = None, optimize_dns: Optional[bool] = None, + storage_options: Dict[str, Any] = {}, ) -> None: """This function converts a dataset into chunks, possibly in a distributed way. @@ -386,6 +402,7 @@ def optimize( start_method: The start method used by python multiprocessing package. Default to spawn unless running inside an interactive shell like Ipython. optimize_dns: Whether the optimized dns should be used. + storage_options: Storage options for the cloud provider. """ _check_version_and_prompt_upgrade(__version__) @@ -440,7 +457,9 @@ def optimize( "\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`." ) - _assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint) + _assert_dir_has_index_file( + _output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options + ) if not isinstance(inputs, StreamingDataLoader): resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs)) @@ -456,7 +475,9 @@ def optimize( num_workers = num_workers or _get_default_num_workers() state_dict = {rank: 0 for rank in range(num_workers)} - existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None + existing_index_file_content = ( + read_index_file_content(_output_dir, storage_options) if mode == "append" else None + ) if existing_index_file_content is not None: for chunk in existing_index_file_content["chunks"]: @@ -478,6 +499,7 @@ def optimize( use_checkpoint=use_checkpoint, item_loader=item_loader, start_method=start_method, + storage_options=storage_options, ) with optimize_dns_context(optimize_dns if optimize_dns is not None else False): @@ -490,6 +512,7 @@ def optimize( compression=compression, encryption=encryption, existing_index=existing_index_file_content, + storage_options=storage_options, ) ) return None @@ -558,14 +581,19 @@ class CopyInfo: new_filename: str -def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional[int] = os.cpu_count()) -> None: +def merge_datasets( + input_dirs: List[str], + output_dir: str, + max_workers: Optional[int] = os.cpu_count(), + storage_options: Dict[str, Any] = {}, +) -> None: """Enables to merge multiple existing optimized datasets into a single optimized dataset. Args: input_dirs: A list of directories pointing to the existing optimized datasets. output_dir: The directory where the merged dataset would be stored. max_workers: Number of workers for multithreading - + storage_options: Storage options for the cloud provider. """ if len(input_dirs) == 0: raise ValueError("The input directories needs to be defined.") @@ -580,12 +608,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs): raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.") - input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] + input_dirs_file_content = [read_index_file_content(input_dir, storage_options) for input_dir in resolved_input_dirs] if any(file_content is None for file_content in input_dirs_file_content): raise ValueError("One of the provided input_dir doesn't have an index file.") - output_dir_file_content = read_index_file_content(resolved_output_dir) + output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options) if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") @@ -622,16 +650,16 @@ def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures: List[concurrent.futures.Future] = [] for copy_info in copy_infos: - future = executor.submit(_apply_copy, copy_info, resolved_output_dir) + future = executor.submit(_apply_copy, copy_info, resolved_output_dir, storage_options) futures.append(future) for future in _tqdm(concurrent.futures.as_completed(futures), total=len(futures)): future.result() - _save_index(index_json, resolved_output_dir) + _save_index(index_json, resolved_output_dir, storage_options) -def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: +def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Dict[str, Any] = {}) -> None: if output_dir.url is None and copy_info.input_dir.url is None: assert copy_info.input_dir.path assert output_dir.path @@ -644,13 +672,13 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: input_filepath = os.path.join(copy_info.input_dir.url, copy_info.old_filename) output_filepath = os.path.join(output_dir.url, copy_info.new_filename) - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, storage_options) fs_provider.copy(input_filepath, output_filepath) else: raise NotImplementedError -def _save_index(index_json: Dict, output_dir: Dir) -> None: +def _save_index(index_json: Dict, output_dir: Dir, storage_options: Dict[str, Any] = {}) -> None: if output_dir.url is None: assert output_dir.path with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: @@ -663,5 +691,5 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None: remote_path = os.path.join(output_dir.url, _INDEX_FILENAME) - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, storage_options) fs_provider.upload_file(f.name, remote_path) diff --git a/src/litdata/processing/utilities.py b/src/litdata/processing/utilities.py index 66dc5e150..8788fd669 100644 --- a/src/litdata/processing/utilities.py +++ b/src/litdata/processing/utilities.py @@ -200,7 +200,7 @@ def _get_work_dir() -> str: return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/" -def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: +def read_index_file_content(output_dir: Dir, storage_options: Dict[str, Any] = {}) -> Optional[Dict[str, Any]]: """Read the index file content.""" if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir should be a Dir object.") @@ -221,7 +221,7 @@ def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]: if obj.scheme not in _SUPPORTED_PROVIDERS: not_supported_provider(output_dir.url) - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, storage_options) prefix = output_dir.url.rstrip("/") + "/" diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 9bdf238bd..6233f6099 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -20,7 +20,7 @@ from dataclasses import dataclass from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union from urllib import parse from litdata.constants import _LIGHTNING_SDK_AVAILABLE, _SUPPORTED_PROVIDERS @@ -229,7 +229,9 @@ def _resolve_datasets(dir_path: str) -> Dir: ) -def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool = False) -> None: +def _assert_dir_is_empty( + output_dir: Dir, append: bool = False, overwrite: bool = False, storage_options: Dict[str, Any] = {} +) -> None: if not isinstance(output_dir, Dir): raise ValueError("The provided output_dir isn't a `Dir` Object.") @@ -241,7 +243,7 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool if obj.scheme not in _SUPPORTED_PROVIDERS: not_supported_provider(output_dir.url) - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, storage_options) is_empty = fs_provider.is_empty(output_dir.url) @@ -255,7 +257,10 @@ def _assert_dir_is_empty(output_dir: Dir, append: bool = False, overwrite: bool def _assert_dir_has_index_file( - output_dir: Dir, mode: Optional[Literal["append", "overwrite"]] = None, use_checkpoint: bool = False + output_dir: Dir, + mode: Optional[Literal["append", "overwrite"]] = None, + use_checkpoint: bool = False, + storage_options: Dict[str, Any] = {}, ) -> None: if mode is not None and mode not in ["append", "overwrite"]: raise ValueError(f"The provided `mode` should be either `append` or `overwrite`. Found {mode}.") @@ -302,7 +307,7 @@ def _assert_dir_has_index_file( if obj.scheme not in _SUPPORTED_PROVIDERS: not_supported_provider(output_dir.url) - fs_provider = _get_fs_provider(output_dir.url) + fs_provider = _get_fs_provider(output_dir.url, storage_options) prefix = output_dir.url.rstrip("/") + "/" @@ -326,9 +331,7 @@ def _assert_dir_has_index_file( "\n HINT: If you want to append/overwrite to the existing dataset, use `mode='append | overwrite'`." ) - # all the files (including the index file in overwrite mode) - fs_provider = _get_fs_provider(output_dir.url) - + # delete all the files (including the index file in overwrite mode) if mode == "overwrite" or (mode is None and not use_checkpoint): fs_provider.delete_file_or_directory(output_dir.url)