
Commit 5fe9e93

Update dataloader docstrings (#17061)
1 parent 9cd131c commit 5fe9e93

File tree: 2 files changed, +48 −129 lines


src/lightning/pytorch/core/hooks.py

Lines changed: 15 additions & 111 deletions
@@ -388,11 +388,9 @@ def teardown(self, stage: str) -> None:
        """

    def train_dataloader(self) -> TRAIN_DATALOADERS:
-        """Implement one or more PyTorch DataLoaders for training.
+        """An iterable or collection of iterables specifying training samples.

-        Return:
-            A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
-            In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        The dataloader you return will not be reloaded unless you set
        :paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to
@@ -412,55 +410,15 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
        - :meth:`setup`

        Note:
-            Lightning adds the correct sampler for distributed and arbitrary hardware.
+            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.
-
-        Example::
-
-            # single dataloader
-            def train_dataloader(self):
-                transform = transforms.Compose([transforms.ToTensor(),
-                                                transforms.Normalize((0.5,), (1.0,))])
-                dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
-                                download=True)
-                loader = torch.utils.data.DataLoader(
-                    dataset=dataset,
-                    batch_size=self.batch_size,
-                    shuffle=True
-                )
-                return loader
-
-            # multiple dataloaders, return as list
-            def train_dataloader(self):
-                mnist = MNIST(...)
-                cifar = CIFAR(...)
-                mnist_loader = torch.utils.data.DataLoader(
-                    dataset=mnist, batch_size=self.batch_size, shuffle=True
-                )
-                cifar_loader = torch.utils.data.DataLoader(
-                    dataset=cifar, batch_size=self.batch_size, shuffle=True
-                )
-                # each batch will be a list of tensors: [batch_mnist, batch_cifar]
-                return [mnist_loader, cifar_loader]
-
-            # multiple dataloader, return as dict
-            def train_dataloader(self):
-                mnist = MNIST(...)
-                cifar = CIFAR(...)
-                mnist_loader = torch.utils.data.DataLoader(
-                    dataset=mnist, batch_size=self.batch_size, shuffle=True
-                )
-                cifar_loader = torch.utils.data.DataLoader(
-                    dataset=cifar, batch_size=self.batch_size, shuffle=True
-                )
-                # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
-                return {'mnist': mnist_loader, 'cifar': cifar_loader}
        """
        raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")

    def test_dataloader(self) -> EVAL_DATALOADERS:
-        r"""
-        Implement one or multiple PyTorch DataLoaders for testing.
+        r"""An iterable or collection of iterables specifying test samples.
+
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        For data processing use the following pattern:

@@ -477,44 +435,19 @@ def test_dataloader(self) -> EVAL_DATALOADERS:
        - :meth:`setup`

        Note:
-            Lightning adds the correct sampler for distributed and arbitrary hardware.
+            Lightning tries to add the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

-        Return:
-            A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples.
-
-        Example::
-
-            def test_dataloader(self):
-                transform = transforms.Compose([transforms.ToTensor(),
-                                                transforms.Normalize((0.5,), (1.0,))])
-                dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
-                                download=True)
-                loader = torch.utils.data.DataLoader(
-                    dataset=dataset,
-                    batch_size=self.batch_size,
-                    shuffle=False
-                )
-
-                return loader
-
-            # can also return multiple dataloaders
-            def test_dataloader(self):
-                return [loader_a, loader_b, ..., loader_n]
-
        Note:
            If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
            this method.
-
-        Note:
-            In the case where you return multiple test dataloaders, the :meth:`test_step`
-            will have an argument ``dataloader_idx`` which matches the order here.
        """
        raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")

    def val_dataloader(self) -> EVAL_DATALOADERS:
-        r"""
-        Implement one or multiple PyTorch DataLoaders for validation.
+        r"""An iterable or collection of iterables specifying validation samples.
+
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        The dataloader you return will not be reloaded unless you set
        :paramref:`~lightning.pytorch.trainer.Trainer.reload_dataloaders_every_n_epochs` to
@@ -528,44 +461,19 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
        - :meth:`setup`

        Note:
-            Lightning adds the correct sampler for distributed and arbitrary hardware
+            Lightning tries to add the correct sampler for distributed and arbitrary hardware
            There is no need to set it yourself.

-        Return:
-            A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
-
-        Examples::
-
-            def val_dataloader(self):
-                transform = transforms.Compose([transforms.ToTensor(),
-                                                transforms.Normalize((0.5,), (1.0,))])
-                dataset = MNIST(root='/path/to/mnist/', train=False,
-                                transform=transform, download=True)
-                loader = torch.utils.data.DataLoader(
-                    dataset=dataset,
-                    batch_size=self.batch_size,
-                    shuffle=False
-                )
-
-                return loader
-
-            # can also return multiple dataloaders
-            def val_dataloader(self):
-                return [loader_a, loader_b, ..., loader_n]
-
        Note:
            If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
            implement this method.
-
-        Note:
-            In the case where you return multiple validation dataloaders, the :meth:`validation_step`
-            will have an argument ``dataloader_idx`` which matches the order here.
        """
        raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")

    def predict_dataloader(self) -> EVAL_DATALOADERS:
-        r"""
-        Implement one or multiple PyTorch DataLoaders for prediction.
+        r"""An iterable or collection of iterables specifying prediction samples.
+
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

@@ -574,15 +482,11 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
        - :meth:`setup`

        Note:
-            Lightning adds the correct sampler for distributed and arbitrary hardware
+            Lightning tries to add the correct sampler for distributed and arbitrary hardware
            There is no need to set it yourself.

        Return:
            A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
-
-        Note:
-            In the case where you return multiple prediction dataloaders, the :meth:`predict_step`
-            will have an argument ``dataloader_idx`` which matches the order here.
        """
        raise MisconfigurationException(
            "`predict_dataloader` must be implemented to be used with the Lightning Trainer"

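The hooks above now describe their return value as "an iterable or collection of iterables" rather than strictly :class:`torch.utils.data.DataLoader` objects, and they point to the :ref:`multiple-dataloaders` docs instead of carrying inline examples. The dict pattern shown in the deleted examples still works; here is a minimal sketch of it, assuming torchvision is installed (the module name ``MultiLoaderModel``, the data root, and the batch size are illustrative placeholders, not part of this commit)::

    import lightning.pytorch as pl
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import CIFAR10, MNIST

    class MultiLoaderModel(pl.LightningModule):
        # Only the dataloader hook is sketched; training_step,
        # configure_optimizers, etc. are omitted for brevity.
        batch_size = 32

        def train_dataloader(self):
            # Any iterable or collection of iterables is accepted; returning a
            # dict makes each training batch arrive shaped as
            # {'mnist': batch_mnist, 'cifar': batch_cifar}.
            transform = transforms.ToTensor()
            mnist = MNIST(root="./data", train=True, transform=transform, download=True)
            cifar = CIFAR10(root="./data", train=True, transform=transform, download=True)
            return {
                "mnist": DataLoader(mnist, batch_size=self.batch_size, shuffle=True),
                "cifar": DataLoader(cifar, batch_size=self.batch_size, shuffle=True),
            }
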
src/lightning/pytorch/trainer/trainer.py

Lines changed: 33 additions & 18 deletions
@@ -500,17 +500,20 @@ def fit(
        Args:
            model: Model to fit.

-            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
-                :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying training samples.
-                In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
+            train_dataloaders: An iterable or collection of iterables specifying training samples.
+                Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.

-            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
+            val_dataloaders: An iterable or collection of iterables specifying validation samples.

            ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
                If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch.

-            datamodule: An instance of :class:`~lightning.pytorch.core.datamodule.LightningDataModule`.
+            datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.
+
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
        """
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
@@ -573,8 +576,9 @@ def validate(
        Args:
            model: The model to validate.

-            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
-                or a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying validation samples.
+            dataloaders: An iterable or collection of iterables specifying validation samples.
+                Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
                If ``None`` and the model instance was passed, use the current weights.
@@ -583,7 +587,10 @@ def validate(

            verbose: If True, prints the validation results.

-            datamodule: An instance of :class:`~lightning.pytorch.core.datamodule.LightningDataModule`.
+            datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
+
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        Returns:
            List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks
666673
Args:
667674
model: The model to test.
668675
669-
dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
670-
or a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying test samples.
676+
dataloaders: An iterable or collection of iterables specifying test samples.
677+
Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
678+
the `:class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
671679
672680
ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
673681
If ``None`` and the model instance was passed, use the current weights.
@@ -676,7 +684,10 @@ def test(

            verbose: If True, prints the test results.

-            datamodule: An instance of :class:`~lightning.pytorch.core.datamodule.LightningDataModule`.
+            datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
+
+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

        Returns:
            List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks
@@ -760,10 +771,12 @@ def predict(
        Args:
            model: The model to predict with.

-            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
-                or a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` specifying prediction samples.
+            dataloaders: An iterable or collection of iterables specifying predict samples.
+                Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook.

-            datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders.
+            datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
+                the :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook.

            return_predictions: Whether to return predictions.
                ``True`` by default except when an accelerator that spawns processes is used (not supported).
@@ -773,6 +786,8 @@ def predict(
                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                if a checkpoint callback is configured.

+        For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.
+
        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

@@ -1360,27 +1375,27 @@ def is_last_batch(self) -> bool:
        return self.fit_loop.epoch_loop.batch_progress.is_last_batch

    @property
-    def train_dataloader(self) -> TRAIN_DATALOADERS:
+    def train_dataloader(self) -> Optional[TRAIN_DATALOADERS]:
        """The training dataloader(s) used during ``trainer.fit()``."""
        if (combined_loader := self.fit_loop._combined_loader) is not None:
            return combined_loader.iterables

    @property
-    def val_dataloaders(self) -> EVAL_DATALOADERS:
+    def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]:
        """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``."""
        if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None:
            return combined_loader.iterables
        elif (combined_loader := self.validate_loop._combined_loader) is not None:
            return combined_loader.iterables

    @property
-    def test_dataloaders(self) -> EVAL_DATALOADERS:
+    def test_dataloaders(self) -> Optional[EVAL_DATALOADERS]:
        """The test dataloader(s) used during ``trainer.test()``."""
        if (combined_loader := self.test_loop._combined_loader) is not None:
            return combined_loader.iterables

    @property
-    def predict_dataloaders(self) -> EVAL_DATALOADERS:
+    def predict_dataloaders(self) -> Optional[EVAL_DATALOADERS]:
        """The prediction dataloader(s) used during ``trainer.predict()``."""
        if (combined_loader := self.predict_loop._combined_loader) is not None:
            return combined_loader.iterables
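
The trainer-side docstrings mirror the hook wording, and the four dataloader properties gain ``Optional[...]`` return types because each returns ``None`` until the corresponding loop has populated its combined loader. A usage sketch under the same assumptions as above (``MultiLoaderModel`` is the illustrative placeholder module from the previous sketch)::

    import lightning.pytorch as pl

    model = MultiLoaderModel()
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2)

    # An iterable or collection of iterables may also be passed directly, e.g.
    # trainer.fit(model, train_dataloaders={"mnist": loader_a, "cifar": loader_b})
    trainer.fit(model)

    # Per the new Optional annotations: None before fit() has run, the
    # collection of iterables afterwards.
    assert trainer.train_dataloader is not None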
