@@ -16,7 +16,7 @@
 import logging
 import os
 from importlib import reload
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch.utils.data import Dataset, IterableDataset
@@ -32,6 +32,8 @@
 from torch.utils.data.sampler import BatchSampler, Sampler
 
 from lightning.data.streaming import Cache
+from lightning.data.streaming.combined import CombinedStreamingDataset
 from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
+from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.sampler import CacheBatchSampler
 from lightning.data.utilities.env import _DistributedEnv
@@ -248,7 +250,7 @@ def _next_data(self) -> Any:
             raise e
 
 
-class StreamingDataLoader(DataLoader):
+class CacheDataLoader(DataLoader):
     __doc__ = DataLoader.__doc__
 
     def __init__(
@@ -271,16 +273,16 @@ def __init__(
     ) -> None:
         if sampler:
             raise ValueError(
-                "The StreamingDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
+                "The CacheDataLoader relies on its own internal sampler. Passing a sampler isn't supported."
             )
 
         if batch_sampler:
             raise ValueError(
-                "The StreamingDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
+                "The CacheDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported."
             )
 
         if isinstance(dataset, IterableDataset):
-            raise ValueError("Only map-based dataset are supported by the StreamingDataLoader for now.")
+            raise ValueError("Only map-based datasets are supported by the CacheDataLoader for now.")
 
         if profile and not _VIZ_TRACKER_AVAILABLE:
             raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.")
@@ -294,7 +296,7 @@ def __init__(
 
         if len(cache_list) == 0:
             if cache_dir is None:
-                raise ValueError("You should provide a `cache_dir` filepath to the StreamingDataLoader.")
+                raise ValueError("You should provide a `cache_dir` filepath to the CacheDataLoader.")
 
             dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression)
             cache = dataset._cache
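
For context, a minimal sketch of how the renamed class would be constructed under these checks. The toy dataset and cache path are illustrative, and the keyword names follow the checks visible in this hunk rather than the full signature:

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):  # illustrative map-style dataset; an IterableDataset would raise above
    def __len__(self) -> int:
        return 64

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.randn(3)


# Without an existing cache, `cache_dir` is required so the first epoch
# can write chunks; omitting it triggers the ValueError in this hunk.
loader = CacheDataLoader(ToyDataset(), cache_dir="/tmp/toy_cache", batch_size=8)
```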
@@ -337,3 +339,55 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
             return _SingleProcessDataLoaderIterPatch(self)
         self.check_worker_number_rationality()
         return _MultiProcessingDataLoaderIterPatch(self)
+
+
+class StreamingDataLoader(DataLoader):
+    """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
+    dataset."""
+
+    __doc__ = DataLoader.__doc__
+
+    def __init__(
+        self,
+        dataset: Union[StreamingDataset, CombinedStreamingDataset],
+        *args: Any,
+        batch_size: int = 1,
+        num_workers: int = 0,
+        **kwargs: Any,
+    ) -> None:  # pyright: ignore
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.num_samples_yielded = 0
+        super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs)  # type: ignore
+
+    def __iter__(self) -> Any:
+        if isinstance(self.dataset, StreamingDataset):
+            assert self.batch_size
+            self.num_samples_yielded = 0
+            for batch in super().__iter__():
+                self.num_samples_yielded += self.batch_size
+                yield batch
+        else:
+            yield from super().__iter__()
+
+    def state_dict(self) -> Optional[Dict[str, Any]]:
+        if isinstance(self.dataset, StreamingDataset):
+            assert self.batch_size
+            env = _DistributedEnv.detect()
+            num_samples = self.num_samples_yielded * env.world_size
+            return self.dataset.state_dict(num_samples, self.num_workers, self.batch_size)
+        return self.dataset.state_dict(self.num_workers, self.batch_size)
+
+    def load_state_dict(self, obj: Dict[str, Any]) -> None:
+        """Load a dict containing training state (called from non-worker process).
+
+        This is called on each copy of the dataset when resuming.
+
+        Args:
+            obj (Dict[str, Any]): The state.
+
+        """
+        if isinstance(self.dataset, (StreamingDataset, CombinedStreamingDataset)):
+            self.dataset.load_state_dict(obj)
+        else:
+            raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")
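
To make the resumability contract concrete, here is a hedged sketch of a save/restore round trip. The `StreamingDataset` input path is illustrative, and the exact constructor arguments may differ:

```python
from lightning.data.streaming.dataset import StreamingDataset

# Illustrative input directory pointing at an optimized dataset.
dataset = StreamingDataset("data/optimized")
dataloader = StreamingDataLoader(dataset, batch_size=8, num_workers=2)

state = None
for i, batch in enumerate(dataloader):
    if i == 10:
        # Internally, num_samples_yielded is scaled by the detected
        # world size before being passed to dataset.state_dict(...).
        state = dataloader.state_dict()
        break

# On restart, rebuild the loader over the same dataset and restore the
# state so iteration continues roughly from sample 10 * batch_size onward.
dataloader = StreamingDataLoader(StreamingDataset("data/optimized"), batch_size=8, num_workers=2)
dataloader.load_state_dict(state)
```

Note that multiplying `num_samples_yielded` by `env.world_size` appears to assume every rank consumes the same number of samples per step, which holds when each rank runs an identical loader configuration.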