231 changes: 119 additions & 112 deletions docs/source-pytorch/common/trainer.rst
@@ -413,6 +413,35 @@ Number of devices to train on (``int``), which devices to train on (``list`` or
# Training with GPU Accelerator using total number of gpus available on the system
Trainer(accelerator="gpu")


enable_autolog_hparams
^^^^^^^^^^^^^^^^^^^^^^

Whether to log hyperparameters at the start of a run. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_autolog_hparams=True)

# disable logging hyperparams
trainer = Trainer(enable_autolog_hparams=False)

With the parameter set to ``False``, you can add custom code to log hyperparameters.

.. code-block:: python

    from lightning.pytorch.loggers import CSVLogger

    model = LitModel()
    trainer = Trainer(enable_autolog_hparams=False)
    for logger in trainer.loggers:
        # log different hyperparameter dicts depending on the logger type
        if isinstance(logger, CSVLogger):
            logger.log_hyperparams(hparams_dict_1)
        else:
            logger.log_hyperparams(hparams_dict_2)

You can also call ``self.logger.log_hyperparams(...)`` from inside your ``LightningModule`` to log hyperparameters.
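
For example, a minimal sketch of logging from inside the module (the hook choice and hyperparameter values below are only illustrative):

.. code-block:: python

    class LitModel(LightningModule):
        def on_fit_start(self):
            # log a custom dict of hyperparameters with whichever loggers are attached
            self.logger.log_hyperparams({"hidden_dim": 128, "dropout": 0.1})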


enable_checkpointing
^^^^^^^^^^^^^^^^^^^^

@@ -443,6 +472,40 @@ See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to c
# Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])
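
For reference, ``checkpoint_callback`` above is typically a :class:`~lightning.pytorch.callbacks.ModelCheckpoint` instance. A minimal sketch, assuming you log a metric named ``val_loss``:

.. code-block:: python

    from lightning.pytorch.callbacks import ModelCheckpoint

    # keep the single best checkpoint according to the logged "val_loss" metric
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
    trainer = Trainer(callbacks=[checkpoint_callback])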


enable_model_summary
^^^^^^^^^^^^^^^^^^^^

Whether to enable or disable model summarization. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_model_summary=True)

# disable summarization
trainer = Trainer(enable_model_summary=False)

# enable custom summarization
from lightning.pytorch.callbacks import ModelSummary

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])


enable_progress_bar
^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the progress bar. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_progress_bar=True)

# disable progress bar
trainer = Trainer(enable_progress_bar=False)
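
To keep the progress bar but change how it behaves, you can pass a configured progress bar callback instead; a minimal sketch using the built-in ``TQDMProgressBar`` (the refresh rate is an arbitrary example value):

.. code-block:: python

    from lightning.pytorch.callbacks import TQDMProgressBar

    # refresh the bar every 10 batches instead of every batch
    trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])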


fast_dev_run
^^^^^^^^^^^^

@@ -500,6 +563,39 @@ Gradient clipping value
# default used by the Trainer
trainer = Trainer(gradient_clip_val=None)
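
For example, a minimal sketch of enabling clipping (the values are arbitrary):

.. code-block:: python

    # clip gradients to a maximum norm of 0.5 (the default algorithm is "norm")
    trainer = Trainer(gradient_clip_val=0.5)

    # clip each gradient component to [-0.5, 0.5] instead
    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")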


inference_mode
^^^^^^^^^^^^^^

Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation
(``validate``/``test``/``predict``).

.. testcode::

# default used by the Trainer
trainer = Trainer(inference_mode=True)

# Use `torch.no_grad` instead
trainer = Trainer(inference_mode=False)


With :func:`torch.inference_mode` disabled, you can enable gradients for your model layers if required.

.. code-block:: python

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            preds = self.layer1(batch)
            # selectively re-enable gradient tracking for the layers that need it
            with torch.enable_grad():
                grad_preds = preds.requires_grad_()
                preds2 = self.layer2(grad_preds)


    model = LitModel()
    trainer = Trainer(inference_mode=False)
    trainer.validate(model)


limit_train_batches
^^^^^^^^^^^^^^^^^^^

@@ -871,18 +967,6 @@ See the :doc:`profiler documentation <../tuning/profiler>` for more details.
# advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
trainer = Trainer(profiler="advanced")
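
You can also pass a configured profiler instance instead of a string; a minimal sketch (the output directory and filename are arbitrary examples):

.. code-block:: python

    from lightning.pytorch.profilers import AdvancedProfiler

    # write the profiling report to a file under `dirpath` instead of printing it
    profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")
    trainer = Trainer(profiler=profiler)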

enable_progress_bar
^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the progress bar. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_progress_bar=True)

# disable progress bar
trainer = Trainer(enable_progress_bar=False)

reload_dataloaders_every_n_epochs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -917,28 +1001,6 @@ The pseudocode applies also to the ``val_dataloader``.

.. _replace-sampler-ddp:

use_distributed_sampler
^^^^^^^^^^^^^^^^^^^^^^^

See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.

.. testcode::

# default used by the Trainer
trainer = Trainer(use_distributed_sampler=True)

If you set this to ``False``, you have to add your own distributed sampler:

.. code-block:: python

# in your LightningModule or LightningDataModule
def train_dataloader(self):
dataset = ...
# default used by the Trainer
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
return dataloader


strategy
^^^^^^^^
@@ -982,6 +1044,29 @@ Enable synchronization between batchnorm layers across all GPUs.
trainer = Trainer(sync_batchnorm=True)


use_distributed_sampler
^^^^^^^^^^^^^^^^^^^^^^^

See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.

.. testcode::

# default used by the Trainer
trainer = Trainer(use_distributed_sampler=True)

If you set this to ``False``, you have to add your own distributed sampler:

.. code-block:: python

# in your LightningModule or LightningDataModule
def train_dataloader(self):
dataset = ...
# default used by the Trainer
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
return dataloader


val_check_interval
^^^^^^^^^^^^^^^^^^

@@ -1058,84 +1143,6 @@ Can specify as float, int, or a time-based duration.
# Total number of batches run
total_fit_batches = total_train_batches + total_val_batches
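
For example, a minimal sketch of the float and int forms (the values are arbitrary):

.. code-block:: python

    # run validation every 25% of a training epoch
    trainer = Trainer(val_check_interval=0.25)

    # run validation every 1000 training batches
    trainer = Trainer(val_check_interval=1000)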


enable_model_summary
^^^^^^^^^^^^^^^^^^^^

Whether to enable or disable model summarization. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_model_summary=True)

# disable summarization
trainer = Trainer(enable_model_summary=False)

# enable custom summarization
from lightning.pytorch.callbacks import ModelSummary

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])


inference_mode
^^^^^^^^^^^^^^

Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation
(``validate``/``test``/``predict``).

.. testcode::

# default used by the Trainer
trainer = Trainer(inference_mode=True)

# Use `torch.no_grad` instead
trainer = Trainer(inference_mode=False)


With :func:`torch.inference_mode` disabled, you can enable gradients for your model layers if required.

.. code-block:: python

class LitModel(LightningModule):
def validation_step(self, batch, batch_idx):
preds = self.layer1(batch)
with torch.enable_grad():
grad_preds = preds.requires_grad_()
preds2 = self.layer2(grad_preds)


model = LitModel()
trainer = Trainer(inference_mode=False)
trainer.validate(model)

enable_autolog_hparams
^^^^^^^^^^^^^^^^^^^^^^

Whether to log hyperparameters at the start of a run. Defaults to True.

.. testcode::

# default used by the Trainer
trainer = Trainer(enable_autolog_hparams=True)

# disable logging hyperparams
trainer = Trainer(enable_autolog_hparams=False)

With the parameter set to ``False``, you can add custom code to log hyperparameters.

.. code-block:: python

model = LitModel()
trainer = Trainer(enable_autolog_hparams=False)
for logger in trainer.loggers:
if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
logger.log_hyperparams(hparams_dict_1)
else:
logger.log_hyperparams(hparams_dict_2)

You can also call ``self.logger.log_hyperparams(...)`` from inside your ``LightningModule`` to log hyperparameters.

-----

Trainer class API