
Commit c989a97

tchaton and thomas authored
feat(fr) StreamingDataset: Fault Tolerance v2 1/n (#19196)
Co-authored-by: thomas <[email protected]>
1 parent 9e159e1 commit c989a97

File tree

7 files changed: +383 −13 lines


src/lightning/data/streaming/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -12,14 +12,18 @@
 # limitations under the License.

 from lightning.data.streaming.cache import Cache
+from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
+from lightning.data.streaming.dataloader import StreamingDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader

 __all__ = [
     "Cache",
     "DataProcessor",
     "StreamingDataset",
+    "CombinedStreamingDataset",
+    "StreamingDataLoader",
     "DataTransformRecipe",
     "DataChunkRecipe",
     "TokensLoader",

src/lightning/data/streaming/combined.py

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Any, Dict, Iterator, List, Optional, Sequence

from torch.utils.data import IterableDataset

from lightning.data.streaming.dataset import StreamingDataset


class CombinedStreamingDataset(IterableDataset):
    """The `CombinedStreamingDataset` enables streaming data from multiple `StreamingDataset`s with the sampling
    ratio of your choice.

    Additionally, the `CombinedStreamingDataset` keeps track of the number of samples fetched to enable
    resumability of the datasets.

    """

    def __init__(
        self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None
    ) -> None:
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        num_datasets = len(datasets)

        if weights is None:
            # Inversely weighted based on length
            self._weights = [1 / float(num_datasets)] * num_datasets
        else:
            self._weights = [w / sum(weights) for w in weights]

        self._iterator: Optional[_CombinedDatasetIterator] = None

    def __iter__(self) -> Iterator[Any]:
        assert self._weights
        self._iterator = _CombinedDatasetIterator(self._datasets, self._seed, self._weights)
        return self._iterator

    def state_dict(self, num_workers: int, batch_size: int) -> Dict[str, Any]:
        if self._iterator is None:
            return {}
        return self._iterator.state_dict(num_workers, batch_size)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if len(state_dict) != len(self._datasets):
            raise RuntimeError(f"The provided state doesn't match the current number of datasets: {self._datasets}.")

        for dataset_idx, dataset in enumerate(self._datasets):
            if str(dataset_idx) not in state_dict:
                raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.")

            dataset.load_state_dict(state_dict[str(dataset_idx)])


class _CombinedDatasetIterator(Iterator):
    def __init__(self, datasets: List[StreamingDataset], seed: int, weights: Sequence[float]) -> None:
        self._datasets = datasets
        self._dataset_iters = [iter(dataset) for dataset in datasets]
        self._dataset_indexes = list(range(len(datasets)))
        self._num_samples_yielded = [0 for _ in range(len(datasets))]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self) -> Any:
        # randomly select a dataset index
        (dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)

        # keep track of the number of samples fetched from this dataset
        self._num_samples_yielded[dataset_index] += 1

        # return a new sample
        return next(self._dataset_iters[dataset_index])

    def state_dict(self, num_workers: int = 0, batch_size: int = 1) -> Dict[str, Any]:
        return {
            str(dataset_idx): dataset.state_dict(self._num_samples_yielded[dataset_idx], num_workers, batch_size)
            for dataset_idx, dataset in enumerate(self._datasets)
        }
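
A short usage sketch for `CombinedStreamingDataset`, assuming two already-prepared streaming datasets; the bucket URIs and the 0.7/0.3 weights below are illustrative placeholders, not part of this commit:

    from lightning.data.streaming import CombinedStreamingDataset, StreamingDataset

    # Hypothetical inputs: any two prepared StreamingDataset locations would work here.
    english = StreamingDataset("s3://my-bucket/english")
    french = StreamingDataset("s3://my-bucket/french")

    # Roughly 70% of samples come from the first dataset and 30% from the second.
    # Weights are normalized internally, so [7, 3] would behave the same way.
    combined = CombinedStreamingDataset([english, french], seed=42, weights=[0.7, 0.3])

    for sample in combined:
        ...  # each item is drawn from one of the two datasets at random

Once iteration has started, `combined.state_dict(num_workers, batch_size)` returns one entry per dataset, keyed by its index, which is exactly the structure `load_state_dict` validates above.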

src/lightning/data/streaming/dataloader.py

Lines changed: 60 additions & 6 deletions
@@ -16,7 +16,7 @@
 import logging
 import os
 from importlib import reload
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.utils.data import Dataset, IterableDataset
@@ -32,7 +32,9 @@
 from torch.utils.data.sampler import BatchSampler, Sampler

 from lightning.data.streaming import Cache
+from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
+from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.sampler import CacheBatchSampler
 from lightning.data.utilities.env import _DistributedEnv

@@ -248,7 +250,7 @@ def _next_data(self) -> Any:
             raise e


-class StreamingDataLoader(DataLoader):
+class CacheDataLoader(DataLoader):
     __doc__ = DataLoader.__doc__

     def __init__(
@@ -271,16 +273,16 @@ def __init__(
     ) -> None:
         if sampler:
             raise ValueError(
-                "The StreamingDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
+                "The CacheDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
             )

         if batch_sampler:
             raise ValueError(
-                "The StreamingDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
+                "The CacheDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
             )

         if isinstance(dataset, IterableDataset):
-            raise ValueError("Only map-based dataset are supported by the StreamingDataLoader for now.")
+            raise ValueError("Only map-based dataset are supported by the CacheDataLoader for now.")

         if profile and not _VIZ_TRACKER_AVAILABLE:
             raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.")
@@ -294,7 +296,7 @@ def __init__(

         if len(cache_list) == 0:
             if cache_dir is None:
-                raise ValueError("You should provide a `cache_dir` filepath to the StreamingDataLoader.")
+                raise ValueError("You should provide a `cache_dir` filepath to the CacheDataLoader.")

             dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression)
             cache = dataset._cache
@@ -337,3 +339,55 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
             return _SingleProcessDataLoaderIterPatch(self)
         self.check_worker_number_rationality()
         return _MultiProcessingDataLoaderIterPatch(self)
+
+
+class StreamingDataLoader(DataLoader):
+    """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of
+    the dataset."""
+
+    __doc__ = DataLoader.__doc__
+
+    def __init__(
+        self,
+        dataset: Union[StreamingDataset, CombinedStreamingDataset],
+        *args: Any,
+        batch_size: int = 1,
+        num_workers: int = 0,
+        **kwargs: Any,
+    ) -> None:  # pyright: ignore
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.num_samples_yielded = 0
+        super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs)  # type: ignore
+
+    def __iter__(self) -> Any:
+        if isinstance(self.dataset, StreamingDataset):
+            assert self.batch_size
+            self.num_samples_yielded = 0
+            for batch in super().__iter__():
+                self.num_samples_yielded += self.batch_size
+                yield batch
+        else:
+            yield from super().__iter__()
+
+    def state_dict(self) -> Optional[Dict[str, Any]]:
+        if isinstance(self.dataset, StreamingDataset):
+            assert self.batch_size
+            env = _DistributedEnv.detect()
+            num_samples = self.num_samples_yielded * env.world_size
+            return self.dataset.state_dict(num_samples, self.num_workers, self.batch_size)
+        return self.dataset.state_dict(self.num_workers, self.batch_size)
+
+    def load_state_dict(self, obj: Dict[str, Any]) -> None:
+        """Load a dict containing training state (called from non-worker process).
+
+        This is called on each copy of the dataset when resuming.
+
+        Args:
+            obj (Dict[str, Any]): The state.
+
+        """
+        if isinstance(self.dataset, (StreamingDataset, CombinedStreamingDataset)):
+            self.dataset.load_state_dict(obj)
+        else:
+            raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")

src/lightning/data/streaming/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ def _checkpoint(self, chunk_index: int) -> None:

         self.last_time = time()

-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self, num_samples_yielded: int = 0, num_workers: int = 0, batch_size: int = 1) -> Dict[str, Any]:
         if _is_in_dataloader_worker():
             raise RuntimeError("The method `state_dict` should only be called in the main process.")
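
With the extended signature, callers such as `StreamingDataLoader.state_dict` above can forward the progress counters explicitly; a hedged sketch of a direct call, with placeholder values:

    # The new keyword parameters all default to a fresh state
    # (num_samples_yielded=0, num_workers=0, batch_size=1).
    state = dataset.state_dict(num_samples_yielded=800, num_workers=2, batch_size=8)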

tests/tests_data/streaming/test_cache.py

Lines changed: 6 additions & 6 deletions
@@ -20,7 +20,7 @@
 import torch
 from lightning import seed_everything
 from lightning.data.streaming import Cache
-from lightning.data.streaming.dataloader import StreamingDataLoader
+from lightning.data.streaming.dataloader import CacheDataLoader
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader
 from lightning.data.streaming.serializers import Serializer
@@ -72,7 +72,7 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):

     cache = Cache(cache_dir, chunk_size=10)
     dataset = ImageDataset(tmpdir, cache, dataset_size, 10)
-    dataloader = StreamingDataLoader(dataset, num_workers=num_workers, batch_size=4)
+    dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4)

     for _ in dataloader:
         pass
@@ -92,15 +92,15 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):

     if distributed_env.world_size == 1:
         indexes = []
-        dataloader = StreamingDataLoader(dataset, num_workers=num_workers, batch_size=4)
+        dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4)
         for batch in dataloader:
             if batch:
                 indexes.extend(batch["index"].numpy().tolist())
         assert len(indexes) == dataset_size

     seed_everything(42)

-    dataloader = StreamingDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True)
+    dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True)
     dataloader_iter = iter(dataloader)

     indexes = []
@@ -194,7 +194,7 @@ def test_cache_with_auto_wrapping(tmpdir):
     os.makedirs(os.path.join(tmpdir, "cache_1"), exist_ok=True)

     dataset = RandomDataset(64, 64)
-    dataloader = StreamingDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_bytes=2 << 12)
+    dataloader = CacheDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_bytes=2 << 12)
     for batch in dataloader:
         assert isinstance(batch, torch.Tensor)
     assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [
@@ -217,7 +217,7 @@ def __len__(self) -> int:

     os.makedirs(os.path.join(tmpdir, "cache_2"), exist_ok=True)
     dataset = RandomDatasetAtRuntime(64, 64)
-    dataloader = StreamingDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_bytes=2 << 12)
+    dataloader = CacheDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_bytes=2 << 12)
     with pytest.raises(ValueError, match="Your dataset items aren't deterministic"):
         for batch in dataloader:
             pass
