From aac48245168a6a96d1a9b9dcb3fad5bab2789622 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 12 Sep 2025 08:40:36 +0200
Subject: [PATCH 1/2] reorder args

---
 docs/source-pytorch/common/trainer.rst | 166 +++++++++++++------
 1 file changed, 86 insertions(+), 80 deletions(-)

diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index a3bdb6bb7b2de..6eca504798747 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -413,6 +413,35 @@ Number of devices to train on (``int``), which devices to train on (``list`` or
     # Training with GPU Accelerator using total number of gpus available on the system
     Trainer(accelerator="gpu")

+
+enable_autolog_hparams
+^^^^^^^^^^^^^^^^^^^^^^
+
+Whether to log hyperparameters at the start of a run. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_autolog_hparams=True)
+
+    # disable logging hyperparams
+    trainer = Trainer(enable_autolog_hparams=False)
+
+With the parameter set to ``False``, you can add custom code to log hyperparameters.
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = Trainer(enable_autolog_hparams=False)
+    for logger in trainer.loggers:
+        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+            logger.log_hyperparams(hparams_dict_1)
+        else:
+            logger.log_hyperparams(hparams_dict_2)
+
+You can also call ``self.logger.log_hyperparams(...)`` inside a ``LightningModule`` to log them.
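+
+For example, one minimal way to do this is from the ``on_fit_start`` hook (the logged values below are only illustrative):
+
+.. code-block:: python
+
+    class LitModel(LightningModule):
+        def on_fit_start(self):
+            if self.logger:  # only log when a logger is attached
+                # illustrative values; log whatever dict of hyperparameters you need
+                self.logger.log_hyperparams({"learning_rate": 1e-3, "hidden_dim": 128})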
+
+
 enable_checkpointing
 ^^^^^^^^^^^^^^^^^^^^

@@ -443,6 +472,40 @@ See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to c
     # Add your callback to the callbacks list
     trainer = Trainer(callbacks=[checkpoint_callback])

+
+enable_model_summary
+^^^^^^^^^^^^^^^^^^^^
+
+Whether to enable or disable the model summarization. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_model_summary=True)
+
+    # disable summarization
+    trainer = Trainer(enable_model_summary=False)
+
+    # enable custom summarization
+    from lightning.pytorch.callbacks import ModelSummary
+
+    trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
+
+
+enable_progress_bar
+^^^^^^^^^^^^^^^^^^^
+
+Whether to enable or disable the progress bar. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_progress_bar=True)
+
+    # disable progress bar
+    trainer = Trainer(enable_progress_bar=False)
+
+
 fast_dev_run
 ^^^^^^^^^^^^

@@ -871,18 +934,6 @@ See the :doc:`profiler documentation <../tuning/profiler>` for more details.
     # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
     trainer = Trainer(profiler="advanced")

-enable_progress_bar
-^^^^^^^^^^^^^^^^^^^
-
-Whether to enable or disable the progress bar. Defaults to True.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(enable_progress_bar=True)
-
-    # disable progress bar
-    trainer = Trainer(enable_progress_bar=False)
-
 reload_dataloaders_every_n_epochs
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -917,28 +968,6 @@ The pseudocode applies also to the ``val_dataloader``.

 .. _replace-sampler-ddp:

-use_distributed_sampler
-^^^^^^^^^^^^^^^^^^^^^^^
-
-See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(use_distributed_sampler=True)
-
-By setting to False, you have to add your own distributed sampler:
-
-.. code-block:: python
-
-    # in your LightningModule or LightningDataModule
-    def train_dataloader(self):
-        dataset = ...
-        # default used by the Trainer
-        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
-        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
-        return dataloader
-
 strategy
 ^^^^^^^^

@@ -982,6 +1011,29 @@ Enable synchronization between batchnorm layers across all GPUs.
     trainer = Trainer(sync_batchnorm=True)

+
+use_distributed_sampler
+^^^^^^^^^^^^^^^^^^^^^^^
+
+See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(use_distributed_sampler=True)
+
+If you set this to ``False``, you have to add your own distributed sampler:
+
+.. code-block:: python
+
+    # in your LightningModule or LightningDataModule
+    def train_dataloader(self):
+        dataset = ...
+        # default used by the Trainer
+        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
+        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
+        return dataloader
+
+
 val_check_interval
 ^^^^^^^^^^^^^^^^^^

@@ -1059,25 +1111,6 @@ Can specify as float, int, or a time-based duration.

     total_fit_batches = total_train_batches + total_val_batches

-enable_model_summary
-^^^^^^^^^^^^^^^^^^^^
-
-Whether to enable or disable the model summarization. Defaults to True.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(enable_model_summary=True)
-
-    # disable summarization
-    trainer = Trainer(enable_model_summary=False)
-
-    # enable custom summarization
-    from lightning.pytorch.callbacks import ModelSummary
-
-    trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
-
-
 inference_mode
 ^^^^^^^^^^^^^^

@@ -1109,33 +1142,6 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your mode
     trainer = Trainer(inference_mode=False)
     trainer.validate(model)

-enable_autolog_hparams
-^^^^^^^^^^^^^^^^^^^^^^
-
-Whether to log hyperparameters at the start of a run. Defaults to True.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(enable_autolog_hparams=True)
-
-    # disable logging hyperparams
-    trainer = Trainer(enable_autolog_hparams=False)
-
-With the parameter set to false, you can add custom code to log hyperparameters.
-
-.. code-block:: python
-
-    model = LitModel()
-    trainer = Trainer(enable_autolog_hparams=False)
-    for logger in trainer.loggers:
-        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
-            logger.log_hyperparams(hparams_dict_1)
-        else:
-            logger.log_hyperparams(hparams_dict_2)
-
-You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log.
-
 -----

 Trainer class API

From e766a0af6665c93bbca7634b1e1dc08c56831150 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 12 Sep 2025 08:42:11 +0200
Subject: [PATCH 2/2] reorder

---
 docs/source-pytorch/common/trainer.rst | 65 +++++++++++++-------------
 1 file changed, 33 insertions(+), 32 deletions(-)

diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 6eca504798747..d63bdeee1f5cd 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -563,6 +563,39 @@ Gradient clipping value
     # default used by the Trainer
     trainer = Trainer(gradient_clip_val=None)

+
+inference_mode
+^^^^^^^^^^^^^^
+
+Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
+(``validate``/``test``/``predict``).
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(inference_mode=True)
+
+    # Use `torch.no_grad` instead
+    trainer = Trainer(inference_mode=False)
+
+
+With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.
+
+.. code-block:: python
+
+    class LitModel(LightningModule):
+        def validation_step(self, batch, batch_idx):
+            preds = self.layer1(batch)
+            with torch.enable_grad():
+                grad_preds = preds.requires_grad_()
+                preds2 = self.layer2(grad_preds)
+
+
+    model = LitModel()
+    trainer = Trainer(inference_mode=False)
+    trainer.validate(model)
+
+
 limit_train_batches
 ^^^^^^^^^^^^^^^^^^^

@@ -1110,38 +1143,6 @@ Can specify as float, int, or a time-based duration.

     # Total number of batches run
     total_fit_batches = total_train_batches + total_val_batches
-
-inference_mode
-^^^^^^^^^^^^^^
-
-Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
-(``validate``/``test``/``predict``)
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(inference_mode=True)
-
-    # Use `torch.no_grad` instead
-    trainer = Trainer(inference_mode=False)
-
-
-With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.
-
-.. code-block:: python
-
-    class LitModel(LightningModule):
-        def validation_step(self, batch, batch_idx):
-            preds = self.layer1(batch)
-            with torch.enable_grad():
-                grad_preds = preds.requires_grad_()
-                preds2 = self.layer2(grad_preds)
-
-
-    model = LitModel()
-    trainer = Trainer(inference_mode=False)
-    trainer.validate(model)
-
 -----

 Trainer class API