Commit e37652f: Improve map and chunkify (#18901)

Authored and committed by tchaton, with co-authors thomas and pre-commit-ci[bot].

Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 85933f3)

1 parent: 9816536

File tree: 9 files changed (+332, -32 lines)


requirements/app/app.txt (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
-lightning-cloud ==0.5.44 # Must be pinned to ensure compatibility
+lightning-cloud ==0.5.46 # Must be pinned to ensure compatibility
 packaging
 typing-extensions >=4.0.0, <4.8.0
 deepdiff >=5.7.0, <6.6.0

src/lightning/data/__init__.py (2 additions, 1 deletion)

@@ -1,10 +1,11 @@
 from lightning.data.datasets import LightningDataset, LightningIterableDataset
 from lightning.data.streaming.dataset import StreamingDataset
-from lightning.data.streaming.map import map
+from lightning.data.streaming.functions import map, optimize

 __all__ = [
     "LightningDataset",
     "StreamingDataset",
     "LightningIterableDataset",
     "map",
+    "optimize",
 ]
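The new import line makes both helpers importable from the package root. Below is a minimal sketch of how they might be called; the diff does not show their signatures, so the parameter names (inputs, output_dir) and the callable shapes are assumptions, not the confirmed API:

from lightning.data import map, optimize

def resize(filepath):
    ...  # process one input file and write the result out

def tokenize(text):
    return text.split()  # return one optimized sample per input

# Hypothetical calls, assuming (fn, inputs, output_dir)-style signatures:
map(resize, inputs=["a.jpeg", "b.jpeg"], output_dir="processed/")
optimize(tokenize, inputs=["hello world"], output_dir="chunks/")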

src/lightning/data/streaming/cache.py (2 additions, 2 deletions)

@@ -18,15 +18,15 @@
 from lightning.data.datasets.env import _DistributedEnv
 from lightning.data.streaming.constants import (
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.reader import BinaryReader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.writer import BinaryWriter

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
     from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir

 logger = logging.Logger(__name__)

src/lightning/data/streaming/constants.py (1 addition, 1 deletion)

@@ -21,7 +21,7 @@
 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
 _VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 = RequirementCache("lightning-cloud>=0.5.42")
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 = RequirementCache("lightning-cloud>=0.5.46")
 _BOTO3_AVAILABLE = RequirementCache("boto3")

 # DON'T CHANGE ORDER
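For readers unfamiliar with these version gates: RequirementCache (from lightning_utilities, which this module relies on) evaluates truthy only when the installed package satisfies the requirement string, so the constant doubles as an import guard. A minimal sketch of the pattern:

from lightning_utilities.core.imports import RequirementCache

_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46 = RequirementCache("lightning-cloud>=0.5.46")

if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
    # Only attempted when a compatible lightning-cloud is installed, so older
    # environments never hit an ImportError at module import time.
    from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir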

src/lightning/data/streaming/data_processor.py (38 additions, 17 deletions)
@@ -5,11 +5,12 @@
 import traceback
 import types
 from abc import abstractmethod
+from dataclasses import dataclass
 from multiprocessing import Process, Queue
 from queue import Empty
 from shutil import copyfile, rmtree
 from time import sleep, time
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse

 import torch

@@ -21,7 +22,7 @@
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.fabric.accelerators.cuda import is_cuda_available

@@ -35,7 +36,7 @@
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten, tree_unflatten

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
     from lightning_cloud.resolver import _LightningSrcResolver, _LightningTargetResolver

 if _BOTO3_AVAILABLE:
@@ -160,10 +161,14 @@ def _remove_target(input_dir: str, cache_dir: str, queue_in: Queue) -> None:
     # 3. Iterate through the paths and delete them sequentially.
     for path in paths:
         if input_dir:
-            cached_filepath = path.replace(input_dir, cache_dir)
+            if not path.startswith(cache_dir):
+                path = path.replace(input_dir, cache_dir)

-            if os.path.exists(cached_filepath):
-                os.remove(cached_filepath)
+            if os.path.exists(path):
+                os.remove(path)
+
+        elif os.path.exists(path) and "s3_connections" not in path:
+            os.remove(path)


 def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, remote_output_dir: str) -> None:
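The new startswith guard fixes a subtle bug: a queued path may already point inside the cache directory, and blindly substituting input_dir would mangle it. A self-contained illustration with made-up paths:

input_dir = "/data"
cache_dir = "/cache/data"

for path in ["/data/img.jpeg", "/cache/data/img.jpeg"]:
    # Without the guard, the second path would become "/cache/cache/data/img.jpeg".
    if not path.startswith(cache_dir):
        path = path.replace(input_dir, cache_dir)
    print(path)  # both iterations print "/cache/data/img.jpeg"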
@@ -387,7 +392,9 @@ def _collect_paths(self) -> None:
             }

             if len(indexed_paths) == 0:
-                raise ValueError(f"The provided item {item} didn't contain any filepaths. {flattened_item}")
+                raise ValueError(
+                    f"The provided item {item} didn't contain any filepaths. The input_dir is {self.input_dir}."
+                )

             paths = []
             for index, path in indexed_paths.items():

@@ -548,7 +555,7 @@ def __init__(self) -> None:
     def _setup(self, name: Optional[str]) -> None:
         self._name = name

-    def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None:
+    def _done(self, delete_cached_files: bool, remote_output_dir: Any) -> None:
         pass

@@ -578,7 +585,6 @@ def prepare_item(self, item_metadata: T) -> Any:  # type: ignore

     def _done(self, delete_cached_files: bool, remote_output_dir: str) -> None:
         num_nodes = _get_num_nodes()
-        assert self._name
         cache_dir = _get_cache_dir(self._name)

         chunks = [file for file in os.listdir(cache_dir) if file.endswith(".bin")]

@@ -647,6 +653,14 @@ def prepare_item(self, output_dir: str, item_metadata: T) -> None:  # type: ignore
         """Use your item metadata to process your files and save the file outputs into `output_dir`."""


+@dataclass
+class PrettyDirectory:
+    """Holds a directory and its URL."""
+
+    directory: str
+    url: str
+
+
 class DataProcessor:
     def __init__(
         self,
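The new PrettyDirectory dataclass lets a resolver hand back a human-readable directory alongside its backing URL; DataProcessor (see the constructor hunks below) prints the former and keeps the latter. A small usage sketch with illustrative values:

from dataclasses import dataclass

@dataclass
class PrettyDirectory:
    """Holds a directory and its URL."""

    directory: str
    url: str

remote = PrettyDirectory(
    directory="/teamspace/datasets/imagenet",  # what the user sees printed
    url="s3://some-bucket/datasets/imagenet",  # what the processor uses internally
)
print(f"Storing the files under {remote.directory}")
remote_output_dir = remote.url  # DataProcessor unwraps to the URL after printing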
@@ -656,10 +670,11 @@
         num_downloaders: Optional[int] = None,
         delete_cached_files: bool = True,
         src_resolver: Optional[Callable[[str], Optional[str]]] = None,
-        fast_dev_run: Optional[bool] = None,
+        fast_dev_run: Optional[Union[bool, int]] = None,
         remote_input_dir: Optional[str] = None,
-        remote_output_dir: Optional[str] = None,
+        remote_output_dir: Optional[Union[str, PrettyDirectory]] = None,
         random_seed: Optional[int] = 42,
+        version: Optional[int] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machine into chunks to make
         training faster.
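Widening fast_dev_run to Union[bool, int] lets callers cap a dev run at an exact item count rather than the default size (note the _DEFAULT_FAST_DEV_RUN_ITEMS constant imported above). A sketch; only the parameters visible in this hunk are confirmed, and the leading name argument is an assumption from context:

from lightning.data.streaming.data_processor import DataProcessor

# Hypothetical constructions for illustration:
DataProcessor(name="my-dataset", fast_dev_run=True)  # process a small default number of items
DataProcessor(name="my-dataset", fast_dev_run=100)   # process exactly 100 items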
@@ -692,18 +707,22 @@
         self.remote_input_dir = (
             str(remote_input_dir)
             if remote_input_dir is not None
-            else ((self.src_resolver(input_dir) if input_dir else None) if self.src_resolver else None)
+            else ((self.src_resolver(str(input_dir)) if input_dir else None) if self.src_resolver else None)
         )
         self.remote_output_dir = (
             remote_output_dir
             if remote_output_dir is not None
-            else (self.dst_resolver(name) if self.dst_resolver else None)
+            else (self.dst_resolver(name, version=version) if self.dst_resolver else None)
         )
         if self.remote_output_dir:
             self.name = self._broadcast_object(self.name)
             # Ensure the remote src dir is the same across all ranks
             self.remote_output_dir = self._broadcast_object(self.remote_output_dir)
-            print(f"Storing the files under {self.remote_output_dir}")
+            if isinstance(self.remote_output_dir, PrettyDirectory):
+                print(f"Storing the files under {self.remote_output_dir.directory}")
+                self.remote_output_dir = self.remote_output_dir.url
+            else:
+                print(f"Storing the files under {self.remote_output_dir}")

         self.random_seed = random_seed

@@ -725,7 +744,7 @@ def run(self, data_recipe: DataRecipe) -> None:
         user_items: List[Any] = data_recipe.prepare_structure(self.input_dir)

         if not isinstance(user_items, list):
-            raise ValueError("The setup_fn should return a list of item metadata.")
+            raise ValueError("The `prepare_structure` should return a list of item metadata.")

         # Associate the items to the workers based on num_nodes and node_rank
         begins, workers_user_items = _associated_items_to_workers(self.num_workers, user_items)
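The reworded error message names the actual hook: run() calls data_recipe.prepare_structure(input_dir) and requires a list back. A minimal recipe sketch of that contract; the subclass is illustrative and omits the other hooks a real recipe needs:

import os

from lightning.data.streaming.data_processor import DataRecipe

class MyRecipe(DataRecipe):  # illustrative subclass, not from this commit
    def prepare_structure(self, input_dir):
        # Must return a list of item metadata, one entry per work item.
        return [os.path.join(input_dir, f) for f in os.listdir(input_dir)]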
@@ -779,6 +798,8 @@
             w.join(0)

         print("Workers are finished.")
+        if self.remote_output_dir:
+            assert isinstance(self.remote_output_dir, str)
         data_recipe._done(self.delete_cached_files, self.remote_output_dir)
         print("Finished data processing!")

@@ -856,15 +877,15 @@ def _cleanup_cache(self) -> None:

         # Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_dir):
-            rmtree(cache_dir)
+            rmtree(cache_dir, ignore_errors=True)

         os.makedirs(cache_dir, exist_ok=True)

         cache_data_dir = _get_cache_data_dir(self.name)

         # Cleanup the cache data folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_data_dir):
-            rmtree(cache_data_dir)
+            rmtree(cache_data_dir, ignore_errors=True)

         os.makedirs(cache_data_dir, exist_ok=True)
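Switching to ignore_errors=True makes cache cleanup best-effort: shutil.rmtree otherwise raises if a file disappears mid-deletion or cannot be removed, which can happen when a previous run left the cache in a partial state. Standard-library behavior, shown in isolation with an illustrative path:

import os
from shutil import rmtree

cache_dir = "/tmp/example_cache"  # illustrative path
if os.path.exists(cache_dir):
    rmtree(cache_dir, ignore_errors=True)  # swallows OSError instead of raising
os.makedirs(cache_dir, exist_ok=True)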
