Commit cb06f09

tchaton, awaelchli, and thomas authored and committed
Improve Streaming Dataset API (#18882)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas <[email protected]>
(cherry picked from commit c1437cc)
1 parent a55e3d3 commit cb06f09

File tree

4 files changed, +125 -175 lines changed

src/lightning/data/streaming/dataset.py

Lines changed: 13 additions & 15 deletions
@@ -20,7 +20,7 @@
 from lightning.data.streaming import Cache
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
-from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle, TruncatedShuffle
+from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
 
 
 class StreamingDataset(IterableDataset):
@@ -32,7 +32,8 @@ def __init__(
         version: Optional[Union[int, Literal["latest"]]] = "latest",
         cache_dir: Optional[str] = None,
         item_loader: Optional[BaseItemLoader] = None,
-        shuffle: Union[bool, Literal["truncated", "full"]] = "truncated",
+        shuffle: bool = False,
+        drop_last: bool = False,
         seed: int = 42,
     ) -> None:
         """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -43,10 +44,15 @@ def __init__(
             cache_dir: The cache dir where the data would be stored.
             item_loader: The logic to load an item from a chunk.
             shuffle: Whether to shuffle the data.
+            drop_last: If `True`, drops the last items to ensure that
+                all processes/workers return the same amount of data.
             seed: Random seed for shuffling.
 
         """
         super().__init__()
+        if not isinstance(shuffle, bool):
+            raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")
+
         self.cache = Cache(name=name, version=version, cache_dir=cache_dir, item_loader=item_loader, chunk_bytes=1)
 
         self.cache._reader._try_load_config()
@@ -56,18 +62,10 @@ def __init__(
 
         self.distributed_env = _DistributedEnv.detect()
 
-        if isinstance(shuffle, bool):
-            _shuffle = TruncatedShuffle(self.cache, seed) if shuffle else NoShuffle(self.cache, seed)
-
-        if isinstance(shuffle, str):
-            if shuffle == "truncated":
-                _shuffle = TruncatedShuffle(self.cache, seed)
-            elif shuffle == "full":
-                _shuffle = FullShuffle(self.cache, seed)
-            else:
-                raise ValueError(f"The provided shuffle doesn't exist. Found {shuffle}")
-
-        self.shuffle: Shuffle = _shuffle
+        self.shuffle: Shuffle = (
+            FullShuffle(self.cache, seed, drop_last) if shuffle else NoShuffle(self.cache, seed, drop_last)
+        )
+        self.drop_last = drop_last
         self.worker_env: Optional[_WorkerEnv] = None
         self.worker_chunks: List[int] = []
         self.worker_intervals: List[List[int]] = []
@@ -84,7 +82,7 @@ def __len__(self) -> int:
         return self.shuffle.get_len(self.distributed_env, self.current_epoch)
 
     def __iter__(self) -> "StreamingDataset":
-        chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_process(
+        chunks_per_replica, intervals_per_replica = self.shuffle.get_chunks_and_intervals_per_ranks(
             self.distributed_env, self.current_epoch
         )
         current_chunks = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
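
The net effect of this file's changes: `shuffle` becomes a strict boolean (validated in `__init__`), truncation moves behind the new `drop_last` flag, and `FullShuffle` becomes the only shuffling strategy. A minimal usage sketch, assuming a dataset already optimised and cached under the hypothetical name "my_dataset":

    from lightning.data.streaming.dataset import StreamingDataset

    # "my_dataset" is a placeholder; the data must already have been prepared
    # with the DatasetOptimiser class mentioned in the docstring above.
    dataset = StreamingDataset(name="my_dataset", shuffle=True, drop_last=True)

    # Strings are no longer accepted for shuffle after this change:
    # StreamingDataset(name="my_dataset", shuffle="truncated")  # raises ValueError

    for item in dataset:
        ...  # chunks are distributed across ranks by FullShuffle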

src/lightning/data/streaming/shuffle.py

Lines changed: 50 additions & 86 deletions
@@ -24,17 +24,27 @@
 class Shuffle(ABC):
     """Shuffle describe how to distribute chunked datasets across processes and workers."""
 
-    def __init__(self, cache: Cache, seed: int):
+    def __init__(self, cache: Cache, seed: int, drop_last: bool):
         self.cache = cache
         self.seed = seed
+        self.drop_last = drop_last
         self.random_state = None
 
-    @abstractmethod
+    @lru_cache(maxsize=10)
     def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
-        pass
+        _, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(distributed_env, current_epoch)
+
+        if self.drop_last:
+            items_per_process = [
+                sum((interval[-1] - interval[0]) for interval in intervals) for intervals in intervals_per_ranks
+            ]
+            min_items_per_process = min(items_per_process)
+            return min_items_per_process
+
+        return sum((interval[-1] - interval[0]) for interval in intervals_per_ranks[distributed_env.global_rank])
 
     @abstractmethod
-    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+    def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
         pass
 
     @abstractmethod
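
In other words, `get_len` is now shared by all strategies: with `drop_last` every rank reports the minimum item count across ranks, otherwise each rank reports its own count. A toy illustration of that arithmetic (the interval values are invented for the example):

    # Each [begin, end] interval contributes end - begin items.
    intervals_per_ranks = [
        [[0, 50], [50, 120]],  # rank 0 -> 120 items
        [[120, 230]],          # rank 1 -> 110 items
    ]

    items_per_rank = [sum(end - begin for begin, end in intervals) for intervals in intervals_per_ranks]

    # drop_last=True: both ranks agree on len(dataset) == 110.
    assert min(items_per_rank) == 110

    # drop_last=False: rank 0 would report 120, rank 1 would report 110.
    assert items_per_rank[0] == 120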
@@ -43,79 +53,29 @@ def __call__(self, array: np.ndarray) -> List[int]:
 
 
 class NoShuffle(Shuffle):
-    """NoShuffle doesn't shuffle the items and ensure all the processes receive the same number of items."""
+    """NoShuffle doesn't shuffle the items and ensure all the processes receive the same number of items if drop_last
+    is True."""
 
     @lru_cache(maxsize=10)
-    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
-        _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
-        min_items_per_process = min(
-            [sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process]
-        )
-        return min_items_per_process
-
-    @lru_cache(maxsize=10)
-    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+    def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
         self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
         chunk_intervals = self.cache.get_chunk_intervals()
         indexes = list(range(len(chunk_intervals)))
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes]
 
-        chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-        intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+        chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+        intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
         for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)):
             replica_index = index % distributed_env.world_size
-            chunks_per_process[replica_index].append(chunk_index)
-            intervals_per_process[replica_index].append(chunk_interval)
+            chunks_per_ranks[replica_index].append(chunk_index)
+            intervals_per_ranks[replica_index].append(chunk_interval)
 
-        return chunks_per_process, intervals_per_process
+        return chunks_per_ranks, intervals_per_ranks
 
     def __call__(self, array: np.ndarray) -> List[int]:
         return array.tolist()
 
 
-class TruncatedShuffle(Shuffle):
-    """TruncatedShuffle shuffles the chunks and associates them to the ranks.
-
-    As the number of items in a chunk varies, it is possible for a rank to end up with more or less items.
-
-    To ensure the same fixed dataset length for all ranks, we compute the minimum number of items across all ranks.
-
-    For the ranks with more items than the minimum, the remaining items are dropped.
-
-    Note: This is the fastest sampling strategy but at the cost of losing items.
-
-    """
-
-    @lru_cache(maxsize=10)
-    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
-        _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
-        min_items_per_process = min(
-            [sum([(interval[-1] - interval[0]) for interval in intervals]) for intervals in intervals_per_process]
-        )
-        return min_items_per_process
-
-    @lru_cache(maxsize=10)
-    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
-        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
-        chunk_intervals = self.cache.get_chunk_intervals()
-        indexes = range(len(chunk_intervals))
-        shuffled_indexes = self.random_state.permutation(indexes)
-        shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
-
-        chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-        intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
-        for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)):
-            replica_index = index % distributed_env.world_size
-            chunks_per_process[replica_index].append(chunk_index)
-            intervals_per_process[replica_index].append(chunk_interval)
-
-        return chunks_per_process, intervals_per_process
-
-    def __call__(self, array: np.ndarray) -> List[int]:
-        assert self.random_state
-        return self.random_state.permutation(array).tolist()
-
-
 class FullShuffle(Shuffle):
     """FullShuffle shuffles the chunks and associates them to the ranks.
 
@@ -135,36 +95,40 @@ class FullShuffle(Shuffle):
     """
 
     @lru_cache(maxsize=10)
-    def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
-        _, intervals_per_process = self.get_chunks_and_intervals_per_process(distributed_env, current_epoch)
-        min_items_per_process = min([sum([(i[-1] - i[0]) for i in intervals]) for intervals in intervals_per_process])
-        return min_items_per_process
-
-    @lru_cache(maxsize=10)
-    def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
+    def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
         self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
+
+        # 1. Get the intervals
         chunk_intervals = self.cache.get_chunk_intervals()
+
+        # 2. Shuffle them
         indexes = range(len(chunk_intervals))
         shuffled_indexes = self.random_state.permutation(indexes)
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
 
+        # 3. Compute the items budget of each rank
         num_items = sum([(interval[-1] - interval[0]) for interval in chunk_intervals])
-        num_items_per_process: List[int] = [
-            num_items // distributed_env.world_size for _ in range(distributed_env.world_size)
+        num_items_per_ranks: List[int] = [
+            num_items // distributed_env.world_size + num_items % distributed_env.world_size
+            if rank == distributed_env.world_size - 1 and not self.drop_last
+            else num_items // distributed_env.world_size
+            for rank in range(distributed_env.world_size)
         ]
-        chunks_per_process: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
-        intervals_per_process: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+        chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
+        intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
+
+        # 4. Assign the chunk & intervals to each rank
         for chunk_index, chunk_interval in zip(shuffled_indexes, shuffled_chunk_intervals):
-            process_index = 0
+            rank = 0
 
             while True:
-                if process_index == len(num_items_per_process):
+                if rank == len(num_items_per_ranks):
                     break
 
-                items_left_to_assign = num_items_per_process[process_index]
+                items_left_to_assign = num_items_per_ranks[rank]
 
                 if items_left_to_assign == 0:
-                    process_index += 1
+                    rank += 1
                     continue
 
                 items_in_chunk = chunk_interval[-1] - chunk_interval[0]
@@ -173,19 +137,19 @@ def get_chunks_and_intervals_per_process(self, distributed_env: _DistributedEnv,
                     break
 
                 if items_in_chunk > items_left_to_assign:
-                    chunks_per_process[process_index].append(chunk_index)
+                    chunks_per_ranks[rank].append(chunk_index)
                     begin, end = chunk_interval
-                    intervals_per_process[process_index].append([begin, begin + items_left_to_assign])
-                    chunk_interval = (begin + items_left_to_assign + 1, end)
-                    num_items_per_process[process_index] = 0
-                    process_index += 1
+                    intervals_per_ranks[rank].append([begin, begin + items_left_to_assign])
+                    chunk_interval = (begin + items_left_to_assign, end)
+                    num_items_per_ranks[rank] = 0
+                    rank += 1
                 else:
-                    chunks_per_process[process_index].append(chunk_index)
-                    intervals_per_process[process_index].append(chunk_interval)
-                    num_items_per_process[process_index] -= items_in_chunk
+                    chunks_per_ranks[rank].append(chunk_index)
+                    intervals_per_ranks[rank].append(chunk_interval)
+                    num_items_per_ranks[rank] -= items_in_chunk
                     break
 
-        return chunks_per_process, intervals_per_process
+        return chunks_per_ranks, intervals_per_ranks
 
     def __call__(self, array: np.ndarray) -> List[int]:
         assert self.random_state
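
To see what the budget and assignment steps above do end to end, here is a self-contained sketch of the same greedy algorithm on invented chunk intervals (world_size=2, drop_last=False). It is not the library code, but it reproduces steps 3 and 4, including the off-by-one fix where the next rank resumes at `begin + items_left_to_assign` instead of one item later:

    chunk_intervals = [(0, 7), (7, 12), (12, 20)]  # 20 items across 3 chunks
    world_size = 2
    drop_last = False

    # Step 3: item budget per rank; the last rank absorbs the remainder
    # unless drop_last is set.
    num_items = sum(end - begin for begin, end in chunk_intervals)
    budgets = [
        num_items // world_size + (num_items % world_size if rank == world_size - 1 and not drop_last else 0)
        for rank in range(world_size)
    ]

    # Step 4: walk the chunks, splitting a chunk when it overflows a budget.
    intervals_per_ranks = [[] for _ in range(world_size)]
    for begin, end in chunk_intervals:
        rank = 0
        while begin < end and rank < world_size:
            if budgets[rank] == 0:
                rank += 1
                continue
            take = min(budgets[rank], end - begin)
            intervals_per_ranks[rank].append([begin, begin + take])
            budgets[rank] -= take
            begin += take  # the off-by-one fix: no item is skipped at the split

    print(intervals_per_ranks)  # [[[0, 7], [7, 10]], [[10, 12], [12, 20]]]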

tests/tests_data/streaming/test_cache.py

Lines changed: 3 additions & 0 deletions
@@ -162,12 +162,15 @@ def test_cache_with_simple_format(tmpdir):
 
     cache = Cache(cache_dir, chunk_bytes=90)
 
+    # you encode data
     for i in range(100):
         cache[i] = i
 
+    # I am done, write the index ...
     cache.done()
     cache.merge()
 
+    # please, decode the data for me.
     for i in range(100):
         assert i == cache[i]
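
Outside the test suite, the same encode/merge/decode round trip might look like this sketch (the temporary directory and the `chunk_bytes` value are arbitrary choices):

    import tempfile

    from lightning.data.streaming import Cache

    with tempfile.TemporaryDirectory() as cache_dir:
        cache = Cache(cache_dir, chunk_bytes=90)

        # Encode: items are written out as chunk files bounded by chunk_bytes.
        for i in range(100):
            cache[i] = i

        # Finalise pending chunks and merge the per-worker indexes.
        cache.done()
        cache.merge()

        # Decode: items can be read back by index.
        assert all(cache[i] == i for i in range(100))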

0 commit comments
