Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str:
return os.path.join(cache_dir, name.lstrip("/"))


def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2) -> Any:
def _wait_for_file_to_exist(remote_filepath: str, sleep_time: int = 2, storage_options: Dict[str, Any] = {}) -> Any:
"""Wait until the file exists."""
file_exists = False
fs_provider = _get_fs_provider(remote_filepath)
fs_provider = _get_fs_provider(remote_filepath, storage_options)
while not file_exists:
file_exists = fs_provider.exists(remote_filepath)
if not file_exists:
Expand All @@ -121,7 +121,9 @@ def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb:
# 1. `queue_in`: A queue that receives the (index, paths) from where the data is to be downloaded.
# 2. `queue_out`: A queue that sends the index after the files have been downloaded and ready to be used.
#
def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
def _download_data_target(
input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue, storage_options: Dict[str, Any] = {}
) -> None:
"""Download data from a remote directory to a cache directory to optimise reading."""
fs_provider = None

Expand Down Expand Up @@ -163,7 +165,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
dirpath = os.path.dirname(local_path)
os.makedirs(dirpath, exist_ok=True)
if fs_provider is None:
fs_provider = _get_fs_provider(input_dir.url)
fs_provider = _get_fs_provider(input_dir.url, storage_options)
fs_provider.download_file(path, local_path)

elif os.path.isfile(path):
Expand Down Expand Up @@ -221,12 +223,14 @@ def keep_path(path: str) -> bool:
# 2. `remove_queue`: After uploading, the file is sent to the remove queue,
# so it can be deleted from the cache directory.
#
def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None:
def _upload_fn(
upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir, storage_options: Dict[str, Any] = {}
) -> None:
"""Upload optimised chunks from a local to remote dataset directory."""
obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path)

if obj.scheme in _SUPPORTED_PROVIDERS:
fs_provider = _get_fs_provider(output_dir.url)
fs_provider = _get_fs_provider(output_dir.url, storage_options)

while True:
data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get()
Expand Down Expand Up @@ -465,6 +469,7 @@ def __init__(
checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = None,
checkpoint_next_index: Optional[int] = None,
item_loader: Optional[BaseItemLoader] = None,
storage_options: Dict[str, Any] = {},
) -> None:
"""The BaseWorker is responsible to process the user data."""
self.worker_index = worker_index
Expand Down Expand Up @@ -499,6 +504,7 @@ def __init__(
self.use_checkpoint: bool = use_checkpoint
self.checkpoint_chunks_info: Optional[List[Dict[str, Any]]] = checkpoint_chunks_info
self.checkpoint_next_index: Optional[int] = checkpoint_next_index
self.storage_options = storage_options

def run(self) -> None:
try:
Expand Down Expand Up @@ -695,6 +701,7 @@ def _start_downloaders(self) -> None:
self.cache_data_dir,
to_download_queue,
self.ready_to_process_queue,
self.storage_options,
),
)
p.start()
Expand Down Expand Up @@ -734,6 +741,7 @@ def _start_uploaders(self) -> None:
self.remove_queue,
self.cache_chunks_dir,
self.output_dir,
self.storage_options,
),
)
p.start()
Expand Down Expand Up @@ -848,8 +856,9 @@ def prepare_item(self, *args: Any, **kwargs: Any) -> Any:
"""
pass

def __init__(self) -> None:
def __init__(self, storage_options: Dict[str, Any] = {}) -> None:
self._name: Optional[str] = None
self.storage_options = storage_options

def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
return _Result(size=size)
Expand All @@ -862,8 +871,9 @@ def __init__(
chunk_bytes: Optional[Union[int, str]] = None,
compression: Optional[str] = None,
encryption: Optional[Encryption] = None,
storage_options: Dict[str, Any] = {},
):
super().__init__()
super().__init__(storage_options)
if chunk_size is not None and chunk_bytes is not None:
raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.")

Expand Down Expand Up @@ -938,7 +948,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
local_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

if obj.scheme in _SUPPORTED_PROVIDERS:
fs_provider = _get_fs_provider(output_dir.url)
fs_provider = _get_fs_provider(output_dir.url, self.storage_options)
fs_provider.upload_file(
local_filepath,
os.path.join(output_dir.url, os.path.basename(local_filepath)),
Expand All @@ -960,7 +970,8 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
if obj.scheme in _SUPPORTED_PROVIDERS:
fs_provider = _get_fs_provider(remote_filepath)
_wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options)
fs_provider = _get_fs_provider(remote_filepath, self.storage_options)
fs_provider.download_file(remote_filepath, node_index_filepath)
elif output_dir.path and os.path.isdir(output_dir.path):
shutil.copyfile(remote_filepath, node_index_filepath)
Expand Down Expand Up @@ -1002,6 +1013,7 @@ def __init__(
use_checkpoint: bool = False,
item_loader: Optional[BaseItemLoader] = None,
start_method: Optional[str] = None,
storage_options: Dict[str, Any] = {},
):
"""Provides an efficient way to process data across multiple machine into chunks to make training faster.

Expand All @@ -1026,6 +1038,7 @@ def __init__(
the format in which the data is stored and optimized for loading.
start_method: The start method used by Python's multiprocessing package. Defaults to spawn unless running
inside an interactive shell like IPython.
storage_options: Storage options for the cloud provider.

"""
# spawn doesn't work in IPython
Expand Down Expand Up @@ -1060,6 +1073,7 @@ def __init__(
self.checkpoint_chunks_info: Optional[List[List[Dict[str, Any]]]] = None
self.checkpoint_next_index: Optional[List[int]] = None
self.item_loader = item_loader
self.storage_options = storage_options

self.state_dict = state_dict or {rank: 0 for rank in range(self.num_workers)}

Expand Down Expand Up @@ -1284,6 +1298,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
self.checkpoint_chunks_info[worker_idx] if self.checkpoint_chunks_info else None,
self.checkpoint_next_index[worker_idx] if self.checkpoint_next_index else None,
self.item_loader,
self.storage_options,
)
worker.start()
workers.append(worker)
Expand Down Expand Up @@ -1340,7 +1355,7 @@ def _cleanup_checkpoints(self) -> None:
prefix = self.output_dir.url.rstrip("/") + "/"
checkpoint_prefix = os.path.join(prefix, ".checkpoints")

fs_provider = _get_fs_provider(self.output_dir.url)
fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
fs_provider.delete_file_or_directory(checkpoint_prefix)

def _save_current_config(self, workers_user_items: List[List[Any]]) -> None:
Expand Down Expand Up @@ -1370,7 +1385,7 @@ def _save_current_config(self, workers_user_items: List[List[Any]]) -> None:
if obj.scheme not in _SUPPORTED_PROVIDERS:
not_supported_provider(self.output_dir.url)

fs_provider = _get_fs_provider(self.output_dir.url)
fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)

prefix = self.output_dir.url.rstrip("/") + "/" + ".checkpoints/"

Expand Down Expand Up @@ -1441,7 +1456,7 @@ def _load_checkpoint_config(self, workers_user_items: List[List[Any]]) -> None:

# download all the checkpoint files in tempdir and read them
with tempfile.TemporaryDirectory() as temp_dir:
fs_provider = _get_fs_provider(self.output_dir.url)
fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
saved_file_dir = fs_provider.download_directory(prefix, temp_dir)

if not os.path.exists(os.path.join(saved_file_dir, "config.json")):
Expand Down
62 changes: 45 additions & 17 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,13 @@ def _get_default_num_workers() -> int:
class LambdaMapRecipe(MapRecipe):
"""Recipe for `map`."""

def __init__(self, fn: Callable[[str, Any], None], inputs: Union[Sequence[Any], StreamingDataLoader]):
super().__init__()
def __init__(
self,
fn: Callable[[str, Any], None],
inputs: Union[Sequence[Any], StreamingDataLoader],
storage_options: Dict[str, Any] = {},
):
super().__init__(storage_options)
self._fn = fn
self._inputs = inputs
self._device: Optional[str] = None
Expand Down Expand Up @@ -160,8 +165,15 @@ def __init__(
compression: Optional[str],
encryption: Optional[Encryption] = None,
existing_index: Optional[Dict[str, Any]] = None,
storage_options: Dict[str, Any] = {},
):
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption)
super().__init__(
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
encryption=encryption,
storage_options=storage_options,
)
self._fn = fn
self._inputs = inputs
self.is_generator = False
Expand Down Expand Up @@ -210,6 +222,7 @@ def map(
batch_size: Optional[int] = None,
start_method: Optional[str] = None,
optimize_dns: Optional[bool] = None,
storage_options: Dict[str, Any] = {},
) -> None:
"""Maps a callable over a collection of inputs, possibly in a distributed way.

Expand All @@ -234,6 +247,7 @@ def map(
start_method: The start method used by Python's multiprocessing package. Defaults to spawn unless running
inside an interactive shell like IPython.
optimize_dns: Whether the optimized dns should be used.
storage_options: Storage options for the cloud provider.
"""
_check_version_and_prompt_upgrade(__version__)

Expand Down Expand Up @@ -273,7 +287,7 @@ def map(
)

if error_when_not_empty:
_assert_dir_is_empty(_output_dir)
_assert_dir_is_empty(_output_dir, storage_options=storage_options)

if not isinstance(inputs, StreamingDataLoader):
input_dir = input_dir or _get_input_dir(inputs)
Expand All @@ -298,10 +312,11 @@ def map(
weights=weights,
reader=reader,
start_method=start_method,
storage_options=storage_options,
)

with optimize_dns_context(optimize_dns if optimize_dns is not None else False):
return data_processor.run(LambdaMapRecipe(fn, inputs))
return data_processor.run(LambdaMapRecipe(fn, inputs, storage_options=storage_options))
return _execute(
f"litdata-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
Expand Down Expand Up @@ -351,6 +366,7 @@ def optimize(
item_loader: Optional[BaseItemLoader] = None,
start_method: Optional[str] = None,
optimize_dns: Optional[bool] = None,
storage_options: Dict[str, Any] = {},
) -> None:
"""This function converts a dataset into chunks, possibly in a distributed way.

Expand Down Expand Up @@ -386,6 +402,7 @@ def optimize(
start_method: The start method used by Python's multiprocessing package. Defaults to spawn unless running
inside an interactive shell like IPython.
optimize_dns: Whether the optimized dns should be used.
storage_options: Storage options for the cloud provider.
"""
_check_version_and_prompt_upgrade(__version__)

Expand Down Expand Up @@ -440,7 +457,9 @@ def optimize(
"\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)

_assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint)
_assert_dir_has_index_file(
_output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options
)

if not isinstance(inputs, StreamingDataLoader):
resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs))
Expand All @@ -456,7 +475,9 @@ def optimize(
num_workers = num_workers or _get_default_num_workers()
state_dict = {rank: 0 for rank in range(num_workers)}

existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None
existing_index_file_content = (
read_index_file_content(_output_dir, storage_options) if mode == "append" else None
)

if existing_index_file_content is not None:
for chunk in existing_index_file_content["chunks"]:
Expand All @@ -478,6 +499,7 @@ def optimize(
use_checkpoint=use_checkpoint,
item_loader=item_loader,
start_method=start_method,
storage_options=storage_options,
)

with optimize_dns_context(optimize_dns if optimize_dns is not None else False):
Expand All @@ -490,6 +512,7 @@ def optimize(
compression=compression,
encryption=encryption,
existing_index=existing_index_file_content,
storage_options=storage_options,
)
)
return None
Expand Down Expand Up @@ -558,14 +581,19 @@ class CopyInfo:
new_filename: str


def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional[int] = os.cpu_count()) -> None:
def merge_datasets(
input_dirs: List[str],
output_dir: str,
max_workers: Optional[int] = os.cpu_count(),
storage_options: Dict[str, Any] = {},
) -> None:
"""Enables to merge multiple existing optimized datasets into a single optimized dataset.

Args:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
max_workers: Maximum number of threads in the pool that runs the copy operations.

storage_options: Storage options for the cloud provider.
"""
if len(input_dirs) == 0:
raise ValueError("The input directories needs to be defined.")
Expand All @@ -580,12 +608,12 @@ def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional
if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")

input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs]
input_dirs_file_content = [read_index_file_content(input_dir, storage_options) for input_dir in resolved_input_dirs]

if any(file_content is None for file_content in input_dirs_file_content):
raise ValueError("One of the provided input_dir doesn't have an index file.")

output_dir_file_content = read_index_file_content(resolved_output_dir)
output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options)

if output_dir_file_content is not None:
raise ValueError("The output_dir already contains an optimized dataset")
Expand Down Expand Up @@ -622,16 +650,16 @@ def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures: List[concurrent.futures.Future] = []
for copy_info in copy_infos:
future = executor.submit(_apply_copy, copy_info, resolved_output_dir)
future = executor.submit(_apply_copy, copy_info, resolved_output_dir, storage_options)
futures.append(future)

for future in _tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
future.result()

_save_index(index_json, resolved_output_dir)
_save_index(index_json, resolved_output_dir, storage_options)


def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Dict[str, Any] = {}) -> None:
if output_dir.url is None and copy_info.input_dir.url is None:
assert copy_info.input_dir.path
assert output_dir.path
Expand All @@ -644,13 +672,13 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
input_filepath = os.path.join(copy_info.input_dir.url, copy_info.old_filename)
output_filepath = os.path.join(output_dir.url, copy_info.new_filename)

fs_provider = _get_fs_provider(output_dir.url)
fs_provider = _get_fs_provider(output_dir.url, storage_options)
fs_provider.copy(input_filepath, output_filepath)
else:
raise NotImplementedError


def _save_index(index_json: Dict, output_dir: Dir) -> None:
def _save_index(index_json: Dict, output_dir: Dir, storage_options: Dict[str, Any] = {}) -> None:
if output_dir.url is None:
assert output_dir.path
with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f:
Expand All @@ -663,5 +691,5 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None:

remote_path = os.path.join(output_dir.url, _INDEX_FILENAME)

fs_provider = _get_fs_provider(output_dir.url)
fs_provider = _get_fs_provider(output_dir.url, storage_options)
fs_provider.upload_file(f.name, remote_path)
Loading
Loading