@@ -452,10 +452,10 @@ def __init__(self, loader: DataLoader) -> None:
452
452
453
453
distributed_env = _DistributedEnv .detect ()
454
454
455
- if self ._loader ._profile_bactches and distributed_env .global_rank == 0 and _VIZ_TRACKER_AVAILABLE :
455
+ if self ._loader ._profile_batches and distributed_env .global_rank == 0 and _VIZ_TRACKER_AVAILABLE :
456
456
from torch .utils .data ._utils import worker
457
457
458
- worker ._worker_loop = _ProfileWorkerLoop (self ._loader ._profile_bactches , self ._loader ._profile_dir )
458
+ worker ._worker_loop = _ProfileWorkerLoop (self ._loader ._profile_batches , self ._loader ._profile_dir )
459
459
460
460
super ().__init__ (loader )
461
461
@@ -479,8 +479,56 @@ def _try_put_index(self) -> None:
479
479
480
480
481
481
class StreamingDataLoader (DataLoader ):
482
- """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
483
- dataset."""
482
+ r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
483
+
484
+ The :class:`~lightning.data.streaming.dataloader.StreamingDataLoader` supports either a
485
+ StreamingDataset or a CombinedStreamingDataset dataset with single- or multi-process loading,
486
+ customizing
487
+ loading order and optional automatic batching (collation) and memory pinning.
488
+
489
+ See :py:mod:`torch.utils.data` documentation page for more details.
490
+
491
+ Args:
492
+ dataset (Dataset): dataset from which to load the data.
493
+ batch_size (int, optional): how many samples per batch to load
494
+ (default: ``1``).
495
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
496
+ at every epoch (default: ``False``).
497
+ num_workers (int, optional): how many subprocesses to use for data
498
+ loading. ``0`` means that the data will be loaded in the main process.
499
+ (default: ``0``)
500
+ collate_fn (Callable, optional): merges a list of samples to form a
501
+ mini-batch of Tensor(s). Used when using batched loading from a
502
+ map-style dataset.
503
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
504
+ into device/CUDA pinned memory before returning them. If your data elements
505
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
506
+ see the example below.
507
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
508
+ from workers. Should always be non-negative. (default: ``0``)
509
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
510
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
511
+ input, after seeding and before data loading. (default: ``None``)
512
+ multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
513
+ ``None``, the default `multiprocessing context`_ of your operating system will
514
+ be used. (default: ``None``)
515
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
516
+ by RandomSampler to generate random indexes and multiprocessing to generate
517
+ ``base_seed`` for workers. (default: ``None``)
518
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
519
+ in advance by each worker. ``2`` means there will be a total of
520
+ 2 * num_workers batches prefetched across all workers. (default value depends
521
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
522
+ Otherwise, if value of ``num_workers > 0`` default is ``2``).
523
+ persistent_workers (bool, optional): If ``True``, the data loader will not shut down
524
+ the worker processes after a dataset has been consumed once. This allows to
525
+ maintain the workers `Dataset` instances alive. (default: ``False``)
526
+ pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
527
+ ``True``.
528
+ profile_batches (bool, int, optional): Whether to record data loading profile and generate a result.json file.
529
+ profile_dir (str, optional): Where to store the recorded trace when profile_batches is enabled.
530
+
531
+ """
484
532
485
533
__doc__ = DataLoader .__doc__
486
534
@@ -490,7 +538,7 @@ def __init__(
490
538
* args : Any ,
491
539
batch_size : int = 1 ,
492
540
num_workers : int = 0 ,
493
- profile_bactches : Union [bool , int ] = False ,
541
+ profile_batches : Union [bool , int ] = False ,
494
542
profile_dir : Optional [str ] = None ,
495
543
prefetch_factor : Optional [int ] = None ,
496
544
** kwargs : Any ,
@@ -501,16 +549,16 @@ def __init__(
501
549
f" Found { dataset } ."
502
550
)
503
551
504
- if profile_bactches and not _VIZ_TRACKER_AVAILABLE :
505
- raise ModuleNotFoundError ("To use profile_bactches , viztracer is required. Run `pip install viztracer`" )
552
+ if profile_batches and not _VIZ_TRACKER_AVAILABLE :
553
+ raise ModuleNotFoundError ("To use profile_batches , viztracer is required. Run `pip install viztracer`" )
506
554
507
- if profile_bactches and num_workers == 0 :
555
+ if profile_batches and num_workers == 0 :
508
556
raise ValueError ("Profiling is supported only with num_workers >= 1." )
509
557
510
558
self .current_epoch = 0
511
559
self .batch_size = batch_size
512
560
self .num_workers = num_workers
513
- self ._profile_bactches = profile_bactches
561
+ self ._profile_batches = profile_batches
514
562
self ._profile_dir = profile_dir
515
563
self ._num_samples_yielded_streaming = 0
516
564
self ._num_samples_yielded_combined : Dict [int , List [Any ]] = {}
0 commit comments