Skip to content

Commit 28b3806

Browse files
authored
StreamingDataloader: Resolve typo (#19370)
1 parent 322f474 commit 28b3806

File tree

2 files changed

+58
-10
lines changed

2 files changed

+58
-10
lines changed

src/lightning/data/streaming/dataloader.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,10 +452,10 @@ def __init__(self, loader: DataLoader) -> None:
452452

453453
distributed_env = _DistributedEnv.detect()
454454

455-
if self._loader._profile_bactches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
455+
if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
456456
from torch.utils.data._utils import worker
457457

458-
worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_bactches, self._loader._profile_dir)
458+
worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
459459

460460
super().__init__(loader)
461461

@@ -479,8 +479,56 @@ def _try_put_index(self) -> None:
479479

480480

481481
class StreamingDataLoader(DataLoader):
482-
"""The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
483-
dataset."""
482+
r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
483+
484+
The :class:`~lightning.data.streaming.dataloader.StreamingDataLoader` supports either a
485+
StreamingDataset or a CombinedStreamingDataset with single- or multi-process loading,
486+
customizing
487+
loading order and optional automatic batching (collation) and memory pinning.
488+
489+
See :py:mod:`torch.utils.data` documentation page for more details.
490+
491+
Args:
492+
dataset (Dataset): dataset from which to load the data.
493+
batch_size (int, optional): how many samples per batch to load
494+
(default: ``1``).
495+
shuffle (bool, optional): set to ``True`` to have the data reshuffled
496+
at every epoch (default: ``False``).
497+
num_workers (int, optional): how many subprocesses to use for data
498+
loading. ``0`` means that the data will be loaded in the main process.
499+
(default: ``0``)
500+
collate_fn (Callable, optional): merges a list of samples to form a
501+
mini-batch of Tensor(s). Used when using batched loading from a
502+
map-style dataset.
503+
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
504+
into device/CUDA pinned memory before returning them. If your data elements
505+
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
506+
see the example below.
507+
timeout (numeric, optional): if positive, the timeout value for collecting a batch
508+
from workers. Should always be non-negative. (default: ``0``)
509+
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
510+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
511+
input, after seeding and before data loading. (default: ``None``)
512+
multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
513+
``None``, the default `multiprocessing context`_ of your operating system will
514+
be used. (default: ``None``)
515+
generator (torch.Generator, optional): If not ``None``, this RNG will be used
516+
by RandomSampler to generate random indexes and multiprocessing to generate
517+
``base_seed`` for workers. (default: ``None``)
518+
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
519+
in advance by each worker. ``2`` means there will be a total of
520+
2 * num_workers batches prefetched across all workers. (default value depends
521+
on the set value for num_workers. If value of num_workers=0 default is ``None``.
522+
Otherwise, if value of ``num_workers > 0`` default is ``2``).
523+
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
524+
the worker processes after a dataset has been consumed once. This allows to
525+
maintain the workers `Dataset` instances alive. (default: ``False``)
526+
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
527+
``True``.
528+
profile_batches (int, bool, optional): Whether to record a data-loading profile and generate a result.json file.
529+
profile_dir (str, optional): Where to store the recorded trace when profile_batches is enabled.
530+
531+
"""
484532

485533
__doc__ = DataLoader.__doc__
486534

@@ -490,7 +538,7 @@ def __init__(
490538
*args: Any,
491539
batch_size: int = 1,
492540
num_workers: int = 0,
493-
profile_bactches: Union[bool, int] = False,
541+
profile_batches: Union[bool, int] = False,
494542
profile_dir: Optional[str] = None,
495543
prefetch_factor: Optional[int] = None,
496544
**kwargs: Any,
@@ -501,16 +549,16 @@ def __init__(
501549
f" Found {dataset}."
502550
)
503551

504-
if profile_bactches and not _VIZ_TRACKER_AVAILABLE:
505-
raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`")
552+
if profile_batches and not _VIZ_TRACKER_AVAILABLE:
553+
raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
506554

507-
if profile_bactches and num_workers == 0:
555+
if profile_batches and num_workers == 0:
508556
raise ValueError("Profiling is supported only with num_workers >= 1.")
509557

510558
self.current_epoch = 0
511559
self.batch_size = batch_size
512560
self.num_workers = num_workers
513-
self._profile_bactches = profile_bactches
561+
self._profile_batches = profile_batches
514562
self._profile_dir = profile_dir
515563
self._num_samples_yielded_streaming = 0
516564
self._num_samples_yielded_combined: Dict[int, List[Any]] = {}

tests/tests_data/streaming/test_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch):
8484
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
8585
)
8686
dataloader = StreamingDataLoader(
87-
dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1
87+
dataset, batch_size=2, profile_batches=profile, profile_dir=str(tmpdir), num_workers=1
8888
)
8989
dataloader_iter = iter(dataloader)
9090
batches = []

0 commit comments

Comments
 (0)