Skip to content

Commit 0c8a1f5

Browse files
tchaton, thomas, and pre-commit-ci[bot]
authored and committed
Add name and version (#18796)
Co-authored-by: thomas <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 3f86ad7)
1 parent 23260a6 commit 0c8a1f5

20 files changed

+160
-52
lines changed

requirements/app/app.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lightning-cloud ==0.5.41 # Must be pinned to ensure compatibility
1+
lightning-cloud ==0.5.42 # Must be pinned to ensure compatibility
22
packaging
33
typing-extensions >=4.0.0, <4.8.0
44
deepdiff >=5.7.0, <6.6.0

src/lightning/data/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from lightning.data.datasets import LightningDataset, LightningIterableDataset
2+
from lightning.data.streaming.dataloader import StreamingDataLoader
3+
from lightning.data.streaming.dataset import StreamingDataset
24

3-
__all__ = ["LightningDataset", "LightningIterableDataset"]
5+
__all__ = ["LightningDataset", "StreamingDataset", "StreamingDataLoader", "LightningIterableDataset"]

src/lightning/data/cache/__init__.py renamed to src/lightning/data/streaming/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from lightning.data.cache.cache import Cache
15-
from lightning.data.cache.dataloader import LightningDataLoader
16-
from lightning.data.cache.dataset_optimizer import DatasetOptimizer
14+
from lightning.data.streaming.cache import Cache
15+
from lightning.data.streaming.dataloader import StreamingDataLoader
16+
from lightning.data.streaming.dataset_optimizer import DatasetOptimizer
1717

18-
__all__ = ["Cache", "DatasetOptimizer", "LightningDataLoader"]
18+
__all__ = ["Cache", "DatasetOptimizer", "StreamingDataLoader"]

src/lightning/data/cache/cache.py renamed to src/lightning/data/streaming/cache.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,31 @@
1313

1414
import logging
1515
import os
16-
from typing import Any, Dict, List, Optional, Tuple, Union
16+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
1717

18-
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
19-
from lightning.data.cache.reader import BinaryReader
20-
from lightning.data.cache.sampler import ChunkedIndex
21-
from lightning.data.cache.writer import BinaryWriter
2218
from lightning.data.datasets.env import _DistributedEnv
19+
from lightning.data.streaming.constants import (
20+
_INDEX_FILENAME,
21+
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
22+
_TORCH_GREATER_EQUAL_2_1_0,
23+
)
24+
from lightning.data.streaming.reader import BinaryReader
25+
from lightning.data.streaming.sampler import ChunkedIndex
26+
from lightning.data.streaming.writer import BinaryWriter
27+
28+
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
29+
from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir
2330

2431
logger = logging.Logger(__name__)
2532

2633

2734
class Cache:
2835
def __init__(
2936
self,
30-
cache_dir: str,
37+
cache_dir: Optional[str] = None,
3138
remote_dir: Optional[str] = None,
39+
name: Optional[str] = None,
40+
version: Optional[Union[int, Literal["latest"]]] = "latest",
3241
compression: Optional[str] = None,
3342
chunk_size: Optional[int] = None,
3443
chunk_bytes: Optional[int] = None,
@@ -40,6 +49,8 @@ def __init__(
4049
cache_dir: The path to where the chunks will be stored.
4150
remote_dir: The path to a remote folder where the data are located.
4251
The scheme needs to be added to the path.
52+
name: The name of dataset in the cloud.
53+
version: The version of the dataset in the cloud to use. By default, we will use the latest.
4354
compression: The name of the algorithm to reduce the size of the chunks.
4455
chunk_bytes: The maximum number of bytes within a chunk.
4556
chunk_size: The maximum number of items within a chunk.
@@ -48,10 +59,20 @@ def __init__(
4859
super().__init__()
4960
if not _TORCH_GREATER_EQUAL_2_1_0:
5061
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
62+
63+
cache_dir = cache_dir if cache_dir else _try_create_cache_dir(name)
64+
if not remote_dir:
65+
remote_dir, has_index_file = _find_remote_dir(name, version)
66+
67+
# When the index exists, we don't care about the chunk_size anymore.
68+
if has_index_file and (chunk_size is None and chunk_bytes is None):
69+
chunk_size = 2
5170
self._writer = BinaryWriter(
5271
str(cache_dir), chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
5372
)
54-
self._reader = BinaryReader(str(cache_dir), remote_dir=remote_dir, compression=compression)
73+
self._reader = BinaryReader(
74+
str(cache_dir), remote_dir=remote_dir, compression=compression, name=name, version=version
75+
)
5576
self._cache_dir = str(cache_dir)
5677
self._is_done = False
5778
self._distributed_env = _DistributedEnv.detect()

src/lightning/data/cache/config.py renamed to src/lightning/data/streaming/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import os
1616
from typing import Any, Dict, List, Optional, Tuple
1717

18-
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
19-
from lightning.data.cache.downloader import get_downloader_cls
20-
from lightning.data.cache.sampler import ChunkedIndex
18+
from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
19+
from lightning.data.streaming.downloader import get_downloader_cls
20+
from lightning.data.streaming.sampler import ChunkedIndex
2121

2222
if _TORCH_GREATER_EQUAL_2_1_0:
2323
from torch.utils._pytree import treespec_loads

src/lightning/data/cache/constants.py renamed to src/lightning/data/streaming/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,5 @@
2020
# This is required for full pytree serialization / deserialization support
2121
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
2222
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
23-
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_41 = RequirementCache("lightning-cloud>=0.5.41")
23+
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 = RequirementCache("lightning-cloud>=0.5.42")
2424
_BOTO3_AVAILABLE = RequirementCache("boto3")

src/lightning/data/cache/dataloader.py renamed to src/lightning/data/streaming/dataloader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
)
3232
from torch.utils.data.sampler import BatchSampler, Sampler
3333

34-
from lightning.data.cache import Cache
35-
from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
36-
from lightning.data.cache.sampler import CacheBatchSampler
3734
from lightning.data.datasets.env import _DistributedEnv
35+
from lightning.data.streaming import Cache
36+
from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
37+
from lightning.data.streaming.sampler import CacheBatchSampler
3838

3939
if _TORCH_GREATER_EQUAL_2_1_0:
4040
from torch.utils._pytree import tree_flatten
@@ -172,7 +172,7 @@ def __call__(
172172
) -> None:
173173
from torch.utils.data._utils import worker
174174

175-
from lightning.data.cache.cache import Cache
175+
from lightning.data.streaming.cache import Cache
176176

177177
enable_profiling = self._global_rank == 0 and worker_id == 0 and _VIZ_TRACKER_AVAILABLE and self._profile
178178

@@ -248,7 +248,7 @@ def _next_data(self) -> Any:
248248
raise e
249249

250250

251-
class LightningDataLoader(DataLoader):
251+
class StreamingDataLoader(DataLoader):
252252
__doc__ = DataLoader.__doc__
253253

254254
def __init__(
@@ -271,16 +271,16 @@ def __init__(
271271
) -> None:
272272
if sampler:
273273
raise ValueError(
274-
"The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
274+
"The StreamingDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
275275
)
276276

277277
if batch_sampler:
278278
raise ValueError(
279-
"The LightningDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
279+
"The StreamingDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
280280
)
281281

282282
if isinstance(dataset, IterableDataset):
283-
raise ValueError("Only map-based dataset are supported by the LightningDataLoader for now.")
283+
raise ValueError("Only map-based dataset are supported by the StreamingDataLoader for now.")
284284

285285
if profile and not _VIZ_TRACKER_AVAILABLE:
286286
raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.")
@@ -294,7 +294,7 @@ def __init__(
294294

295295
if len(cache_list) == 0:
296296
if cache_dir is None:
297-
raise ValueError("You should provide a `cache_dir` filepath to the LightningDataLoader.")
297+
raise ValueError("You should provide a `cache_dir` filepath to the StreamingDataLoader.")
298298

299299
dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression)
300300
cache = dataset._cache
src/lightning/data/streaming/dataset.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from typing import Any, Literal, Optional, Union
15+
16+
from torch.utils.data import Dataset
17+
18+
from lightning.data.streaming import Cache
19+
20+
21+
class StreamingDataset(Dataset):
    """A map-style dataset backed by a ``Cache`` of optimized chunks.

    Intended for data that has already been prepared with the ``DatasetOptimizer`` class.
    """

    def __init__(
        self, name: str, version: Optional[Union[int, Literal["latest"]]] = "latest", cache_dir: Optional[str] = None
    ) -> None:
        """Instantiate the dataset and the cache that serves its samples.

        Arguments:
            name: The name of the optimized dataset in the cloud.
            version: The version of the dataset to use; defaults to the latest.
            cache_dir: The local directory where the cached data is stored.

        """
        super().__init__()
        # The Cache resolves the remote location from (name, version) and streams chunks locally.
        self.cache = Cache(name=name, version=version, cache_dir=cache_dir)

    def __len__(self) -> int:
        """Return the number of items available in the backing cache."""
        return len(self.cache)

    def __getitem__(self, idx: int) -> Any:
        """Fetch the raw cache entry at ``idx`` and pass it through :meth:`getitem`."""
        raw_item = self.cache[idx]
        return self.getitem(raw_item)

    def getitem(self, obj: Any) -> Any:
        """Subclass hook: override with your own logic to transform the raw cache object."""
        return obj

src/lightning/data/cache/dataset_optimizer.py renamed to src/lightning/data/streaming/dataset_optimizer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616
from tqdm import tqdm
1717

1818
from lightning import seed_everything
19-
from lightning.data.cache import Cache
20-
from lightning.data.cache.constants import (
19+
from lightning.data.streaming import Cache
20+
from lightning.data.streaming.constants import (
2121
_BOTO3_AVAILABLE,
2222
_DEFAULT_FAST_DEV_RUN_ITEMS,
2323
_INDEX_FILENAME,
24-
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_41,
24+
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
2525
_TORCH_GREATER_EQUAL_2_1_0,
2626
)
2727

2828
if _TORCH_GREATER_EQUAL_2_1_0:
2929
from torch.utils._pytree import tree_flatten, tree_unflatten
3030

31-
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_41:
31+
if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42:
3232
from lightning_cloud.resolver import _LightningSrcResolver, _LightningTargetResolver
3333

3434
if _BOTO3_AVAILABLE:
@@ -441,6 +441,13 @@ def prepare_dataset_structure(self, src_dir, filepaths)
441441
# [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')]
442442
return [(x[i], x[i+1]) for i in range(len(filepaths) -1)]
443443
444+
def prepare_item(self, obj):
445+
image_filepath, mask_filepath = obj
446+
447+
image = load_and_resize(image_filepath)
448+
mask = load_and_resize(mask_filepath)
449+
return (image, mask)
450+
444451
"""
445452
pass
446453

0 commit comments

Comments
 (0)