@@ -452,10 +452,10 @@ def __init__(self, loader: DataLoader) -> None:

         distributed_env = _DistributedEnv.detect()

-        if self._loader._profile_bactches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
+        if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
             from torch.utils.data._utils import worker

-            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_bactches, self._loader._profile_dir)
+            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)

         super().__init__(loader)

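For context, the change above patches `torch.utils.data._utils.worker._worker_loop` on rank 0 so that each DataLoader worker records a viztracer trace. Below is a minimal sketch of the wrap-and-trace pattern a callable like `_ProfileWorkerLoop` could rely on; the function name `profile_worker_loop`, the output filename, and the argument handling are illustrative assumptions, not the actual implementation.

```python
# Illustrative sketch only -- not the actual _ProfileWorkerLoop implementation.
from viztracer import VizTracer  # `pip install viztracer`


def profile_worker_loop(worker_loop, profile_dir="."):
    """Wrap a DataLoader worker loop so the worker process saves a viztracer trace."""

    def wrapper(*args, **kwargs):
        # One tracer per worker process; the output filename is an assumption.
        tracer = VizTracer(output_file=f"{profile_dir}/result.json")
        tracer.start()
        try:
            worker_loop(*args, **kwargs)
        finally:
            tracer.stop()
            tracer.save()  # writes the JSON trace, viewable with `vizviewer`

    return wrapper
```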
@@ -479,8 +479,56 @@ def _try_put_index(self) -> None:


 class StreamingDataLoader(DataLoader):
-    """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
-    dataset."""
482+ r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
483+
484+ The :class:`~lightning.data.streaming.dataloader.StreamingDataLoader` supports either a
485+ StreamingDataset and CombinedStreamingDataset datasets with single- or multi-process loading,
486+ customizing
487+ loading order and optional automatic batching (collation) and memory pinning.
488+
489+ See :py:mod:`torch.utils.data` documentation page for more details.
490+
491+ Args:
492+ dataset (Dataset): dataset from which to load the data.
493+ batch_size (int, optional): how many samples per batch to load
494+ (default: ``1``).
495+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
496+ at every epoch (default: ``False``).
497+ num_workers (int, optional): how many subprocesses to use for data
498+ loading. ``0`` means that the data will be loaded in the main process.
499+ (default: ``0``)
500+ collate_fn (Callable, optional): merges a list of samples to form a
501+ mini-batch of Tensor(s). Used when using batched loading from a
502+ map-style dataset.
503+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
504+ into device/CUDA pinned memory before returning them. If your data elements
505+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
506+ see the example below.
507+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
508+ from workers. Should always be non-negative. (default: ``0``)
509+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
510+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
511+ input, after seeding and before data loading. (default: ``None``)
512+ multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
513+ ``None``, the default `multiprocessing context`_ of your operating system will
514+ be used. (default: ``None``)
515+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
516+ by RandomSampler to generate random indexes and multiprocessing to generate
517+ ``base_seed`` for workers. (default: ``None``)
518+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
519+ in advance by each worker. ``2`` means there will be a total of
520+ 2 * num_workers batches prefetched across all workers. (default value depends
521+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
522+ Otherwise, if value of ``num_workers > 0`` default is ``2``).
523+ persistent_workers (bool, optional): If ``True``, the data loader will not shut down
524+ the worker processes after a dataset has been consumed once. This allows to
525+ maintain the workers `Dataset` instances alive. (default: ``False``)
526+ pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
527+ ``True``.
528+ profile_batches (int, bool, optional): Whether to record data loading profile and generate a result.json file.
529+ profile_dir (int, bool, optional): Where to store the recorded trace when profile_batches is enabled.
530+
531+ """

     __doc__ = DataLoader.__doc__

@@ -490,7 +538,7 @@ def __init__(
         *args: Any,
         batch_size: int = 1,
         num_workers: int = 0,
-        profile_bactches: Union[bool, int] = False,
+        profile_batches: Union[bool, int] = False,
         profile_dir: Optional[str] = None,
         prefetch_factor: Optional[int] = None,
         **kwargs: Any,
@@ -501,16 +549,16 @@ def __init__(
             f" Found {dataset}."
         )

-        if profile_bactches and not _VIZ_TRACKER_AVAILABLE:
-            raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`")
+        if profile_batches and not _VIZ_TRACKER_AVAILABLE:
+            raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")

-        if profile_bactches and num_workers == 0:
+        if profile_batches and num_workers == 0:
             raise ValueError("Profiling is supported only with num_workers >= 1.")

         self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self._profile_bactches = profile_bactches
+        self._profile_batches = profile_batches
         self._profile_dir = profile_dir
         self._num_samples_yielded_streaming = 0
         self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
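With the typo fixed, enabling the profiler from user code might look like the sketch below. The `StreamingDataset` import path and the dataset URI are assumptions for illustration; as the checks above enforce, `profile_batches` requires viztracer to be installed and `num_workers >= 1`.

```python
# Hypothetical usage sketch; the dataset location is a placeholder.
from lightning.data import StreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader

dataset = StreamingDataset("s3://my-bucket/my-optimized-dataset")  # placeholder URI
loader = StreamingDataLoader(
    dataset,
    batch_size=32,
    num_workers=4,         # profiling requires num_workers >= 1
    profile_batches=10,    # trace only the first 10 batches
    profile_dir="traces",  # where the result.json trace is written
)

for batch in loader:
    ...  # training step
```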