Skip to content

Commit 9bb08c3

Browse files
tchaton, thomas, and pre-commit-ci[bot]
authored and committed
Introduce Dataset Optimizer (#18788)
Co-authored-by: thomas <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 142977d)
1 parent 3aafbea commit 9bb08c3

File tree

14 files changed

+1297
-96
lines changed

14 files changed

+1297
-96
lines changed

requirements/app/app.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lightning-cloud ==0.5.39 # Must be pinned to ensure compatibility
1+
lightning-cloud ==0.5.41 # Must be pinned to ensure compatibility
22
packaging
33
typing-extensions >=4.0.0, <4.8.0
44
deepdiff >=5.7.0, <6.6.0

src/lightning/data/cache/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313

1414
from lightning.data.cache.cache import Cache
1515
from lightning.data.cache.dataloader import LightningDataLoader
16+
from lightning.data.cache.dataset_optimizer import DatasetOptimizer
1617

17-
__all__ = ["Cache", "LightningDataLoader"]
18+
__all__ = ["Cache", "DatasetOptimizer", "LightningDataLoader"]

src/lightning/data/cache/cache.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
from typing import Any, Dict, List, Optional, Tuple, Union
1717

18-
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
18+
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
1919
from lightning.data.cache.reader import BinaryReader
2020
from lightning.data.cache.sampler import ChunkedIndex
2121
from lightning.data.cache.writer import BinaryWriter
@@ -46,11 +46,13 @@ def __init__(
4646
4747
"""
4848
super().__init__()
49-
if not _TORCH_2_1_0_AVAILABLE:
49+
if not _TORCH_GREATER_EQUAL_2_1_0:
5050
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
51-
self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
52-
self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression)
53-
self._cache_dir = cache_dir
51+
self._writer = BinaryWriter(
52+
str(cache_dir), chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
53+
)
54+
self._reader = BinaryReader(str(cache_dir), remote_dir=remote_dir, compression=compression)
55+
self._cache_dir = str(cache_dir)
5456
self._is_done = False
5557
self._distributed_env = _DistributedEnv.detect()
5658

@@ -66,19 +68,27 @@ def __setitem__(self, index: int, data: Any) -> None:
6668
"""Store an item in the writer."""
6769
self._writer[index] = data
6870

71+
def _add_item(self, index: int, data: Any) -> Optional[str]:
72+
"""Store an item in the writer and optionally return the chunk path."""
73+
return self._writer.add_item(index, data)
74+
6975
def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
7076
"""Read an item in the reader."""
7177
if isinstance(index, int):
7278
index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
7379
return self._reader.read(index)
7480

75-
def done(self) -> None:
81+
def done(self) -> Optional[List[str]]:
82+
"""Inform the writer the chunking phase is finished."""
83+
return self._writer.done()
84+
85+
def merge(self, num_workers: int = 1, node_rank: Optional[int] = None) -> None:
7686
"""Inform the writer the chunking phase is finished."""
77-
self._writer.done()
87+
self._writer.merge(num_workers, node_rank=node_rank)
7888

79-
def merge(self, num_workers: int = 1) -> None:
89+
def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
8090
"""Inform the writer the chunking phase is finished."""
81-
self._writer.merge(num_workers)
91+
self._writer._merge_no_wait(node_rank=node_rank)
8292

8393
def __len__(self) -> int:
8494
return self._reader.get_length()

src/lightning/data/cache/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
import os
1616
from typing import Any, Dict, List, Optional, Tuple
1717

18-
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
18+
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
1919
from lightning.data.cache.downloader import get_downloader_cls
2020
from lightning.data.cache.sampler import ChunkedIndex
2121

22-
if _TORCH_2_1_0_AVAILABLE:
22+
if _TORCH_GREATER_EQUAL_2_1_0:
2323
from torch.utils._pytree import treespec_loads
2424

2525

@@ -54,6 +54,7 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]):
5454
if (end - start) != chunk["chunk_size"]:
5555
raise Exception(
5656
"The config intervals doesn't match the number of samples. This shouldn't have happened."
57+
f" Found {end} {start} {chunk['chunk_size']}"
5758
)
5859
self._intervals.append((chunk["interval"][0], chunk["interval"][1]))
5960

src/lightning/data/cache/constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515

1616
_INDEX_FILENAME = "index.json"
1717
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64 MB
18+
_DEFAULT_FAST_DEV_RUN_ITEMS = 10
1819

1920
# This is required for full pytree serialization / deserialization support
20-
_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0")
21+
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
2122
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
23+
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_41 = RequirementCache("lightning-cloud>=0.5.41")
24+
_BOTO3_AVAILABLE = RequirementCache("boto3")

src/lightning/data/cache/dataloader.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
from torch.utils.data.sampler import BatchSampler, Sampler
3333

3434
from lightning.data.cache import Cache
35-
from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_2_1_0_AVAILABLE, _VIZ_TRACKER_AVAILABLE
35+
from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
3636
from lightning.data.cache.sampler import CacheBatchSampler
37-
from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
37+
from lightning.data.datasets.env import _DistributedEnv
3838

39-
if _TORCH_2_1_0_AVAILABLE:
39+
if _TORCH_GREATER_EQUAL_2_1_0:
4040
from torch.utils._pytree import tree_flatten
4141

4242
logger = logging.Logger(__name__)
@@ -154,13 +154,27 @@ def __init__(self, global_rank: int, profile: bool = False) -> None:
154154
self._global_rank = global_rank
155155
self._profile = profile
156156

157-
def __call__(self, dataset_kind: _DatasetKind, *args: Any, **kwargs: Any) -> None:
157+
def __call__(
158+
self,
159+
dataset_kind: Any,
160+
dataset: Any,
161+
index_queue: Any,
162+
data_queue: Any,
163+
done_event: Any,
164+
auto_collation: Any,
165+
collate_fn: Any,
166+
drop_last: Any,
167+
base_seed: Any,
168+
init_fn: Any,
169+
worker_id: Any,
170+
*args: Any,
171+
**kwargs: Any,
172+
) -> None:
158173
from torch.utils.data._utils import worker
159174

160175
from lightning.data.cache.cache import Cache
161176

162-
rank = _WorkerEnv.detect().rank
163-
enable_profiling = self._global_rank == 0 and rank == 0 and _VIZ_TRACKER_AVAILABLE and self._profile
177+
enable_profiling = self._global_rank == 0 and worker_id == 0 and _VIZ_TRACKER_AVAILABLE and self._profile
164178

165179
if enable_profiling:
166180
from viztracer import VizTracer
@@ -180,7 +194,21 @@ def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
180194

181195
_DatasetKind.create_fetcher = create_fetcher_fn # type: ignore
182196

183-
reloaded_worker._worker_loop(dataset_kind, *args, **kwargs)
197+
reloaded_worker._worker_loop(
198+
dataset_kind,
199+
dataset,
200+
index_queue,
201+
data_queue,
202+
done_event,
203+
auto_collation,
204+
collate_fn,
205+
drop_last,
206+
base_seed,
207+
init_fn,
208+
worker_id,
209+
*args,
210+
**kwargs,
211+
)
184212

185213
if dataset_kind == _DatasetKind.Map:
186214
assert fetcher

0 commit comments

Comments
 (0)