Skip to content

Commit 1d5851f

Browse files
tchatonethanwharrislantigathomaspre-commit-ci[bot]
authored
Introduce Cache 1/n (#18642)
Co-authored-by: Ethan Harris <[email protected]> Co-authored-by: Luca Antiga <[email protected]> Co-authored-by: thomas <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6537a05 commit 1d5851f

File tree

21 files changed

+2267
-20
lines changed

21 files changed

+2267
-20
lines changed

.github/checkgroup.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ subprojects:
173173
- "!*.md"
174174
- "!**/*.md"
175175
checks:
176-
- "data-cpu (macOS-11, lightning, 3.10, 2.0)"
177-
- "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
178-
- "data-cpu (windows-2022, lightning, 3.10, 2.0)"
176+
- "data-cpu (macOS-11, lightning, 3.10, 2.1)"
177+
- "data-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
178+
- "data-cpu (windows-2022, lightning, 3.10, 2.1)"
179179

180180
# SECTION: lightning_fabric
181181

.github/workflows/ci-tests-data.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ jobs:
3434
fail-fast: false
3535
matrix:
3636
include:
37-
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
38-
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
39-
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
37+
- { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
38+
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
39+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
4040
# "oldest" versions tests, only on minimum Python
4141
# - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}
4242
# - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}

requirements/data/data.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33

44
lightning-utilities >=0.8.0, <0.10.0
55
# to be able to include also 0.6 and preserve `>` needed for CI min version bypass
6-
torchdata >0.5.9, <0.7.0
6+
torchdata >0.5.9, <=0.7.0
77
# to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass
8-
torch >0.14.0, <2.1.0
8+
torch >0.14.0, <=2.1.0
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from lightning.data.cache.cache import Cache
15+
from lightning.data.cache.dataloader import LightningDataLoader
16+
17+
__all__ = ["Cache", "LightningDataLoader"]

src/lightning/data/cache/cache.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import logging
15+
import os
16+
from typing import Any, Dict, List, Optional, Tuple, Union
17+
18+
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
19+
from lightning.data.cache.reader import BinaryReader
20+
from lightning.data.cache.sampler import ChunkedIndex
21+
from lightning.data.cache.writer import BinaryWriter
22+
from lightning.data.datasets.env import _DistributedEnv
23+
24+
logger = logging.Logger(__name__)
25+
26+
27+
class Cache:
28+
def __init__(
29+
self,
30+
cache_dir: str,
31+
remote_dir: Optional[str] = None,
32+
compression: Optional[str] = None,
33+
chunk_size: Optional[int] = None,
34+
chunk_bytes: Optional[int] = None,
35+
):
36+
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
37+
together in order to accelerate fetching.
38+
39+
Arguments:
40+
cache_dir: The path to where the chunks will be stored.
41+
remote_dir: The path to a remote folder where the data are located.
42+
The scheme needs to be added to the path.
43+
compression: The name of the algorithm to reduce the size of the chunks.
44+
chunk_bytes: The maximum number of bytes within a chunk.
45+
chunk_size: The maximum number of items within a chunk.
46+
47+
"""
48+
super().__init__()
49+
if not _TORCH_2_1_0_AVAILABLE:
50+
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
51+
self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
52+
self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression)
53+
self._cache_dir = cache_dir
54+
self._is_done = False
55+
self._distributed_env = _DistributedEnv.detect()
56+
57+
@property
58+
def filled(self) -> bool:
59+
"""Returns whether the caching phase is done."""
60+
if self._is_done:
61+
return True
62+
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
63+
return self._is_done
64+
65+
def __setitem__(self, index: int, data: Any) -> None:
66+
"""Store an item in the writer."""
67+
self._writer[index] = data
68+
69+
def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
70+
"""Read an item in the reader."""
71+
if isinstance(index, int):
72+
index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
73+
return self._reader.read(index)
74+
75+
def done(self) -> None:
76+
"""Inform the writer the chunking phase is finished."""
77+
self._writer.done()
78+
79+
def merge(self, num_workers: int = 1) -> None:
80+
"""Inform the writer the chunking phase is finished."""
81+
self._writer.merge(num_workers)
82+
83+
def __len__(self) -> int:
84+
return self._reader.get_length()
85+
86+
def get_chunk_interval(self) -> List[Tuple[int, int]]:
87+
return self._reader.get_chunk_interval()
88+
89+
def _get_chunk_index_from_index(self, index: int) -> int:
90+
return self._reader._get_chunk_index_from_index(index)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from abc import ABC, abstractclassmethod, abstractmethod
15+
from typing import Dict, TypeVar
16+
17+
from lightning_utilities.core.imports import RequirementCache, requires
18+
19+
_ZSTD_AVAILABLE = RequirementCache("zstd")
20+
21+
if _ZSTD_AVAILABLE:
22+
import zstd
23+
24+
TCompressor = TypeVar("TCompressor", bound="Compressor")
25+
26+
27+
class Compressor(ABC):
28+
"""Base class for compression algorithm."""
29+
30+
@abstractmethod
31+
def compress(self, data: bytes) -> bytes:
32+
pass
33+
34+
@abstractmethod
35+
def decompress(self, data: bytes) -> bytes:
36+
pass
37+
38+
@abstractclassmethod
39+
def register(cls, compressors: Dict[str, "Compressor"]) -> None:
40+
pass
41+
42+
43+
class ZSTDCompressor(Compressor):
44+
"""Compressor for the zstd package."""
45+
46+
@requires("zstd")
47+
def __init__(self, level: int) -> None:
48+
super().__init__()
49+
self.level = level
50+
self.extension = "zstd"
51+
52+
@property
53+
def name(self) -> str:
54+
return f"{self.extension}:{self.level}"
55+
56+
def compress(self, data: bytes) -> bytes:
57+
return zstd.compress(data, self.level)
58+
59+
def decompress(self, data: bytes) -> bytes:
60+
return zstd.decompress(data)
61+
62+
@classmethod
63+
def register(cls, compressors: Dict[str, "Compressor"]) -> None: # type: ignore
64+
if not _ZSTD_AVAILABLE:
65+
return
66+
67+
# default
68+
compressors["zstd"] = ZSTDCompressor(4)
69+
70+
for level in list(range(1, 23)):
71+
compressors[f"zstd:{level}"] = ZSTDCompressor(level)
72+
73+
74+
_COMPRESSORS: Dict[str, Compressor] = {}
75+
76+
ZSTDCompressor.register(_COMPRESSORS)

src/lightning/data/cache/config.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import json
15+
import os
16+
from typing import Any, Dict, List, Optional, Tuple
17+
18+
from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
19+
from lightning.data.cache.downloader import get_downloader_cls
20+
from lightning.data.cache.sampler import ChunkedIndex
21+
22+
if _TORCH_2_1_0_AVAILABLE:
23+
from torch.utils._pytree import treespec_loads
24+
25+
26+
class ChunksConfig:
27+
def __init__(self, cache_dir: str, remote_dir: Optional[str]):
28+
"""The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
29+
chunk.
30+
31+
Arguments:
32+
cache_dir: The path to cache folder.
33+
remote_dir: The path to a remote folder where the data are located.
34+
The scheme needs to be added to the path.
35+
36+
"""
37+
self._cache_dir = cache_dir
38+
self._intervals: List[Tuple[int, int]] = []
39+
self._config = None
40+
self._chunks = []
41+
self._remote_dir = remote_dir
42+
43+
with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
44+
data = json.load(f)
45+
46+
self._config = data["config"]
47+
48+
self._chunks.extend(data["chunks"])
49+
50+
self._config["data_spec"] = treespec_loads(self._config["data_spec"])
51+
52+
for chunk in self._chunks:
53+
start, end = chunk["interval"]
54+
if (end - start) != chunk["chunk_size"]:
55+
raise Exception(
56+
"The config intervals doesn't match the number of samples. This shouldn't have happened."
57+
)
58+
self._intervals.append((chunk["interval"][0], chunk["interval"][1]))
59+
60+
self._length = sum([chunk["chunk_size"] for chunk in self._chunks])
61+
62+
self._downloader = None
63+
64+
if remote_dir:
65+
self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks)
66+
67+
def download_chunk_from_index(self, chunk_index: int) -> None:
68+
chunk_filename = self._chunks[chunk_index]["filename"]
69+
70+
local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
71+
72+
if os.path.exists(local_chunkpath):
73+
return
74+
75+
if self._downloader is None:
76+
raise RuntimeError("The downloader should be defined.")
77+
78+
self._downloader.download_chunk_from_index(chunk_index)
79+
80+
@property
81+
def intervals(self) -> List[Tuple[int, int]]:
82+
if self._intervals is None:
83+
raise RuntimeError("The intervals should be defined.")
84+
return self._intervals
85+
86+
@property
87+
def data_format(self) -> Any:
88+
if self._config is None:
89+
raise RuntimeError("The config should be defined.")
90+
return self._config["data_format"]
91+
92+
@property
93+
def config(self) -> Dict[str, Any]:
94+
if self._config is None:
95+
raise RuntimeError("The config should be defined.")
96+
return self._config
97+
98+
def _get_chunk_index_from_index(self, index: int) -> int:
99+
for chunk_index, internal in enumerate(self._intervals):
100+
if internal[0] <= index < internal[1]:
101+
return chunk_index
102+
raise ValueError(
103+
f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
104+
)
105+
106+
def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
107+
"""Find the associated chunk metadata."""
108+
chunk = self._chunks[index.chunk_index]
109+
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]
110+
111+
@classmethod
112+
def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]:
113+
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)
114+
115+
if isinstance(remote_dir, str):
116+
downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, [])
117+
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)
118+
119+
if not os.path.exists(cache_index_filepath):
120+
return None
121+
122+
return ChunksConfig(cache_dir, remote_dir)
123+
124+
def __len__(self) -> int:
125+
return self._length
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright The Lightning AI team.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from lightning_utilities.core.imports import RequirementCache
15+
16+
_INDEX_FILENAME = "index.json"
17+
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
18+
19+
# This is required for full pytree serialization / deserialization support
20+
_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0")
21+
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")

0 commit comments

Comments
 (0)