
Commit ce82dc0

tchaton, lantiga, pre-commit-ci[bot], thomas, and Borda committed
Add distributed support for StreamingDataset (#18850)
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 8748258)
1 parent 020f6f0 commit ce82dc0

File tree: 10 files changed (+450 / -37 lines)

src/lightning/data/streaming/cache.py

Lines changed: 6 additions & 2 deletions
@@ -71,6 +71,10 @@ def __init__(
         if has_index_file and (chunk_size is None and chunk_bytes is None):
             chunk_size = 2
 
+        # Add the version to the cache_dir to avoid collisions.
+        if remote_dir and os.path.basename(remote_dir).startswith("version_"):
+            cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))
+
         if cache_dir:
             os.makedirs(cache_dir, exist_ok=True)
 
@@ -116,8 +120,8 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
     def __len__(self) -> int:
         return self._reader.get_length()
 
-    def get_chunk_interval(self) -> List[Tuple[int, int]]:
-        return self._reader.get_chunk_interval()
+    def get_chunk_intervals(self) -> List[Tuple[int, int]]:
+        return self._reader.get_chunk_intervals()
 
     def _get_chunk_index_from_index(self, index: int) -> int:
         return self._reader._get_chunk_index_from_index(index)
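
For reference, a minimal standalone sketch of the new collision-avoidance rule added in Cache.__init__ (the paths below are made up): when the resolved remote directory ends in a version_ folder, that folder name is appended to the local cache directory so different dataset versions never share cached chunks.

import os

cache_dir = "/cache/my_dataset"                   # hypothetical local cache root
remote_dir = "s3://bucket/my_dataset/version_0"   # hypothetical remote dataset location

# Same rule as the added lines above: append the version folder to the cache dir.
if remote_dir and os.path.basename(remote_dir).startswith("version_"):
    cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))

print(cache_dir)  # -> /cache/my_dataset/version_0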

src/lightning/data/streaming/dataset.py

Lines changed: 28 additions & 25 deletions
@@ -20,6 +20,7 @@
 from lightning.data.streaming import Cache
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
+from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle, TruncatedShuffle
 
 
 class StreamingDataset(IterableDataset):
@@ -31,7 +32,7 @@ def __init__(
         version: Optional[Union[int, Literal["latest"]]] = "latest",
         cache_dir: Optional[str] = None,
         item_loader: Optional[BaseItemLoader] = None,
-        shuffle: bool = True,
+        shuffle: Union[bool, Literal["truncated", "full"]] = "truncated",
         seed: int = 42,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -53,13 +54,21 @@ def __init__(
         if not self.cache.filled:
             raise ValueError(f"The provided dataset `{name}` isn't filled up.")
 
-        self.shuffle = shuffle
         self.distributed_env = _DistributedEnv.detect()
-        self.worker_env: Optional[_WorkerEnv] = None
 
-        chunk_intervals = self.cache.get_chunk_interval()
-        self.L = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+        if isinstance(shuffle, bool):
+            _shuffle = TruncatedShuffle(self.cache, seed) if shuffle else NoShuffle(self.cache, seed)
+
+        if isinstance(shuffle, str):
+            if shuffle == "truncated":
+                _shuffle = TruncatedShuffle(self.cache, seed)
+            elif shuffle == "full":
+                _shuffle = FullShuffle(self.cache, seed)
+            else:
+                raise ValueError(f"The provided shuffle doesn't exist. Found {shuffle}")
 
+        self.shuffle: Shuffle = _shuffle
+        self.worker_env: Optional[_WorkerEnv] = None
         self.worker_chunks: List[int] = []
         self.worker_intervals: List[List[int]] = []
         self.current_indexes: List[int] = []
@@ -68,26 +77,16 @@ def __init__(
         self.has_triggered_download = False
         self.min_items_per_replica: Optional[int] = None
         self.seed = seed
-        self.num_iter = 0
+        self.current_epoch = 0
         self.random_state = None
 
     def __len__(self) -> int:
-        return self.L
+        return self.shuffle.get_len(self.distributed_env, self.current_epoch)
 
     def __iter__(self) -> "StreamingDataset":
-        self.random_state = np.random.RandomState(seed=self.seed + self.num_iter)  # type: ignore
-        chunk_intervals = self.cache.get_chunk_interval()
-        indexes = range(len(chunk_intervals))
-        shuffled_indexes = self.random_state.permutation(indexes) if self.shuffle else list(indexes)
-        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
-
-        chunks_per_replica: List[List[int]] = [[] for _ in range(self.distributed_env.world_size)]
-        intervals_per_replica: List[List[List[int]]] = [[] for _ in range(self.distributed_env.world_size)]
-        for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
-            replica_index = index % self.distributed_env.world_size
-            chunks_per_replica[replica_index].append(chunk_index)
-            intervals_per_replica[replica_index].append(chunk_interval)
-
+        chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_process(
+            self.distributed_env, self.current_epoch
+        )
         current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
         current_intervals = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
@@ -105,7 +104,7 @@ def __iter__(self) -> "StreamingDataset":
 
         self.current_indexes = []
         self.chunk_index = 0
-        self.num_iter += 1
+        self.index = 0
 
         return self
 
@@ -115,16 +114,20 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
         return self.cache[index]
 
     def __next__(self) -> Any:
+        # Prevent to create more batch on a given process
+        if self.index >= len(self):
+            self.current_epoch += 1
+            raise StopIteration
+
         # Lazily re-populate the interval to reduce memory usage.
         if len(self.current_indexes) == 0:
             if self.chunk_index == len(self.worker_intervals):
+                self.current_epoch += 1
                 raise StopIteration
 
             interval = self.worker_intervals[self.chunk_index]
-            current_indexes = np.arange(0, interval[1] - interval[0])
-            if self.shuffle:
-                current_indexes = self.random_state.permutation(current_indexes)
-            self.current_indexes = current_indexes.tolist()
+            current_indexes = np.arange(interval[0], interval[1])
+            self.current_indexes = self.shuffle(current_indexes)
             self.chunk_index += 1
 
         # Get the first index
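
The shuffle argument now selects a strategy object instead of a plain boolean: True or "truncated" maps to TruncatedShuffle, "full" to FullShuffle, and False to NoShuffle. A minimal usage sketch (the dataset name is a placeholder and assumes the data was already optimised for streaming):

from lightning.data.streaming.dataset import StreamingDataset

# "my_dataset" is a hypothetical name for an already-prepared streaming dataset.
dataset = StreamingDataset(name="my_dataset", shuffle="full", seed=42)

# Each rank iterates only over the chunks assigned to it and stops after len(dataset) items.
for item in dataset:
    ...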

src/lightning/data/streaming/dataset_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -485,7 +485,7 @@ def __init__(
         self.name = name
         self.src_dir = str(src_dir)
         self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
-        self.num_downloaders = num_downloaders or (1 if fast_dev_run else 2)
+        self.num_downloaders = num_downloaders or 1
         if chunk_size is not None and chunk_bytes is not None:
             raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.")
         self.chunk_size = chunk_size

src/lightning/data/streaming/item_loader.py

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
 
         return self._intervals
 
-    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, _: int) -> torch.Tensor:
+    def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
         while not os.path.exists(chunk_filepath):
             sleep(0.0001)
 
@@ -137,5 +137,5 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
         assert self._dtype
 
         buffer: bytes = self._buffers[chunk_index]
-        offset = self._dtype.itemsize * index
+        offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1)
         return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
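
Since chunk intervals are now expressed in global indices, the byte offset is computed relative to the chunk's begin index. A small worked example with made-up values:

import numpy as np

itemsize = np.dtype("float32").itemsize  # 4 bytes per item
begin, index = 100, 103                  # chunk starts at global index 100; we want global item 103

# Same arithmetic as the changed line above: global index minus the chunk's begin.
offset = itemsize * ((index - begin) if index >= begin else index + 1)
print(offset)  # 12 -> the item located 3 positions into the chunk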

src/lightning/data/streaming/reader.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ def get_length(self) -> int:
 
         return len(self.config)
 
-    def get_chunk_interval(self) -> List[Tuple[int, int]]:
+    def get_chunk_intervals(self) -> List[Tuple[int, int]]:
         """Get the index interval of each chunk."""
         if self._config is None and self._try_load_config() is None:
             raise Exception("The reader index isn't defined.")

src/lightning/data/streaming/sampler.py

Lines changed: 2 additions & 2 deletions
@@ -146,13 +146,13 @@ def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
         yield from self.__iter_indices_per_workers__(worker_indices_batches)
 
     def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
-        chunk_intervals = self._cache.get_chunk_interval()
+        chunk_intervals = self._cache.get_chunk_intervals()
         shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
         yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals)
 
     def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]:
-        chunk_intervals = self._cache.get_chunk_interval()
+        chunk_intervals = self._cache.get_chunk_intervals()
         shuffled_indexes = np.random.permutation(range(len(chunk_intervals)))
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]

src/lightning/data/streaming/shuffle.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
+# Copyright The Lightning AI team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, List
+
+import numpy as np
+
+from lightning.data.datasets.env import _DistributedEnv
+from lightning.data.streaming import Cache
+
+
+class Shuffle(ABC):
+    """Shuffle describes how to distribute chunked datasets across processes and workers."""
+
+    def __init__(self, cache: Cache, seed: int):
+        self.cache = cache
+        self.seed = seed
+        self.random_state = None
+
+    @abstractmethod
+    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
+        pass
+
+    @abstractmethod
+    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+        pass
+
+    @abstractmethod
+    def __call__(self, array: np.ndarray) -> List[int]:
+        pass
+
+
+class NoShuffle(Shuffle):
+    """NoShuffle doesn't shuffle the items and ensures all the processes receive the same number of items."""
+
+    @lru_cache(maxsize=10)
+    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
+        _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
+        min_items_per_process = min(
+            [sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process]
+        )
+        return min_items_per_process
+
+    @lru_cache(maxsize=10)
+    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
+        chunk_intervals = self.cache.get_chunk_intervals()
+        indexes = list(range(len(chunk_intervals)))
+        shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes]
+
+        chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+        intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+        for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)):
+            replica_index = index % distributed_env.world_size
+            chunks_per_process[replica_index].append(chunk_index)
+            intervals_per_process[replica_index].append(chunk_interval)
+
+        return chunks_per_process, intervals_per_process
+
+    def __call__(self, array: np.ndarray) -> List[int]:
+        return array.tolist()
+
+
+class TruncatedShuffle(Shuffle):
+    """TruncatedShuffle shuffles the chunks and associates them to the ranks.
+
+    As the number of items in a chunk varies, it is possible for a rank to end up with more or fewer items.
+
+    To ensure the same fixed dataset length for all ranks, we compute the minimum number of items across all ranks.
+
+    For the ranks with more items than the minimum, the remaining items are dropped.
+
+    Note: This is the fastest sampling strategy but at the cost of losing items.
+
+    """
+
+    @lru_cache(maxsize=10)
+    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
+        _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
+        min_items_per_process = min(
+            [sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process]
+        )
+        return min_items_per_process
+
+    @lru_cache(maxsize=10)
+    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
+        chunk_intervals = self.cache.get_chunk_intervals()
+        indexes = range(len(chunk_intervals))
+        shuffled_indexes = self.random_state.permutation(indexes)
+        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
+
+        chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+        intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+        for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
+            replica_index = index % distributed_env.world_size
+            chunks_per_process[replica_index].append(chunk_index)
+            intervals_per_process[replica_index].append(chunk_interval)
+
+        return chunks_per_process, intervals_per_process
+
+    def __call__(self, array: np.ndarray) -> List[int]:
+        assert self.random_state
+        return self.random_state.permutation(array).tolist()
+
+
+class FullShuffle(Shuffle):
+    """FullShuffle shuffles the chunks and associates them to the ranks.
+
+    As the number of items in a chunk varies, it is possible for a rank to end up with more or fewer items.
+
+    To ensure the same fixed dataset length for all ranks while dropping as few items as possible,
+    we adopt the following strategy.
+
+    We compute the maximum number of items per rank (M) and iterate through the chunks and ranks
+    until we have associated at least M items per rank.
+
+    As a result, we lose at most (number of ranks) items. However, as some chunks are shared across ranks,
+    the same chunk may be downloaded multiple times.
+
+    """
+
+    @lru_cache(maxsize=10)
+    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
+        _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
+        min_items_per_process = min([sum([(i[-1] - i[0]) for i in intervals]) for intervals in intervals_per_process])
+        return min_items_per_process
+
+    @lru_cache(maxsize=10)
+    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
+        chunk_intervals = self.cache.get_chunk_intervals()
+        indexes = range(len(chunk_intervals))
+        shuffled_indexes = self.random_state.permutation(indexes)
+        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
+
+        num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
+        num_items_per_process: List[int] = [
+            num_items // distributed_env.world_size for _ in range(distributed_env.world_size)
+        ]
+        chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+        intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+        for chunk_index, chunk_interval in zip(shuffled_indexes, shuffled_chunk_intervals):
+            process_index = 0
+
+            while True:
+                if process_index == len(num_items_per_process):
+                    break
+
+                items_left_to_assign = num_items_per_process[process_index]
+
+                if items_left_to_assign == 0:
+                    process_index += 1
+                    continue
+
+                items_in_chunk = chunk_interval[-1] - chunk_interval[0]
+
+                if items_in_chunk == 0:
+                    break
+
+                if items_in_chunk > items_left_to_assign:
+                    chunks_per_process[process_index].append(chunk_index)
+                    begin, end = chunk_interval
+                    intervals_per_process[process_index].append([begin, begin + items_left_to_assign])
+                    chunk_interval = (begin + items_left_to_assign + 1, end)
+                    num_items_per_process[process_index] = 0
+                    process_index += 1
+                else:
+                    chunks_per_process[process_index].append(chunk_index)
+                    intervals_per_process[process_index].append(chunk_interval)
+                    num_items_per_process[process_index] -= items_in_chunk
+                    break
+
+        return chunks_per_process, intervals_per_process
+
+    def __call__(self, array: np.ndarray) -> List[int]:
+        assert self.random_state
+        return self.random_state.permutation(array).tolist()
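
To make the assignment strategy concrete, here is a small self-contained simulation of the greedy chunk-splitting idea used by FullShuffle, with made-up chunk intervals and world size (a simplified sketch, not the library's code):

import numpy as np

# Toy setup: 3 chunks covering items [0, 5), [5, 12), [12, 20), distributed across 2 ranks.
chunk_intervals = [(0, 5), (5, 12), (12, 20)]
world_size = 2

rng = np.random.RandomState(seed=42)
shuffled = rng.permutation(range(len(chunk_intervals)))

num_items = sum(end - begin for begin, end in chunk_intervals)      # 20 items in total
items_left = [num_items // world_size for _ in range(world_size)]   # target of 10 items per rank

chunks_per_rank = [[] for _ in range(world_size)]
intervals_per_rank = [[] for _ in range(world_size)]
for chunk_index in shuffled:
    interval = chunk_intervals[chunk_index]
    rank = 0
    while rank < world_size:
        if items_left[rank] == 0:
            rank += 1
            continue
        size = interval[1] - interval[0]
        if size == 0:
            break
        if size > items_left[rank]:
            # Split the chunk: this rank takes only what it still needs, the rest goes to the next rank.
            begin, end = interval
            chunks_per_rank[rank].append(chunk_index)
            intervals_per_rank[rank].append([begin, begin + items_left[rank]])
            interval = (begin + items_left[rank], end)
            items_left[rank] = 0
            rank += 1
        else:
            chunks_per_rank[rank].append(chunk_index)
            intervals_per_rank[rank].append(list(interval))
            items_left[rank] -= size
            break

print(chunks_per_rank)
print(intervals_per_rank)  # each rank ends up with exactly 10 items; the split chunk appears on both ranks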

tests/tests_data/streaming/test_cache.py

Lines changed: 5 additions & 3 deletions
@@ -227,11 +227,13 @@ def test_cache_with_name(tmpdir, monkeypatch):
     os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)
     monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: os.path.join(tmpdir, name))
 
-    monkeypatch.setattr(cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir"), True))
+    monkeypatch.setattr(
+        cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir", "version_0"), True)
+    )
     cache = Cache(name="something")
     assert cache._writer._chunk_size == 2
-    assert cache._writer._cache_dir == os.path.join(tmpdir, "something")
-    assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir")
+    assert cache._writer._cache_dir == os.path.join(tmpdir, "something", "version_0")
+    assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir", "version_0")
 
 
 def test_streaming_dataset(tmpdir, monkeypatch):
