
Commit aac4824

reorder args
1 parent bd1f3fd commit aac4824

File tree: 1 file changed (+86, -80 lines)


docs/source-pytorch/common/trainer.rst

Lines changed: 86 additions & 80 deletions
@@ -413,6 +413,35 @@ Number of devices to train on (``int``), which devices to train on (``list`` or
     # Training with GPU Accelerator using total number of gpus available on the system
     Trainer(accelerator="gpu")

+
+enable_autolog_hparams
+^^^^^^^^^^^^^^^^^^^^^^
+
+Whether to log hyperparameters at the start of a run. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_autolog_hparams=True)
+
+    # disable logging hyperparams
+    trainer = Trainer(enable_autolog_hparams=False)
+
+With the parameter set to false, you can add custom code to log hyperparameters.
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = Trainer(enable_autolog_hparams=False)
+    for logger in trainer.loggers:
+        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+            logger.log_hyperparams(hparams_dict_1)
+        else:
+            logger.log_hyperparams(hparams_dict_2)
+
+You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log.
+
+
 enable_checkpointing
 ^^^^^^^^^^^^^^^^^^^^
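A minimal sketch of the ``self.logger.log_hyperparams(...)`` pattern mentioned in the added lines above, assuming a hypothetical ``LitModel`` and hand-picked hyperparameter names (not part of this commit):

.. code-block:: python

    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):
        def __init__(self, learning_rate=1e-3, hidden_dim=128):
            super().__init__()
            self.learning_rate = learning_rate
            self.hidden_dim = hidden_dim

        def on_fit_start(self):
            # log a hand-picked dict through whatever logger the Trainer was given;
            # self.logger is None when the Trainer runs with logger=False
            if self.logger is not None:
                self.logger.log_hyperparams(
                    {"learning_rate": self.learning_rate, "hidden_dim": self.hidden_dim}
                )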

@@ -443,6 +472,40 @@ See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to c
     # Add your callback to the callbacks list
     trainer = Trainer(callbacks=[checkpoint_callback])

+
+enable_model_summary
+^^^^^^^^^^^^^^^^^^^^
+
+Whether to enable or disable the model summarization. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_model_summary=True)
+
+    # disable summarization
+    trainer = Trainer(enable_model_summary=False)
+
+    # enable custom summarization
+    from lightning.pytorch.callbacks import ModelSummary
+
+    trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
+
+
+enable_progress_bar
+^^^^^^^^^^^^^^^^^^^
+
+Whether to enable or disable the progress bar. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_progress_bar=True)
+
+    # disable progress bar
+    trainer = Trainer(enable_progress_bar=False)
+
+
 fast_dev_run
 ^^^^^^^^^^^^
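Analogous to the custom ``ModelSummary`` example in the hunk above, the progress bar can also be customized rather than only toggled. A small side sketch (not part of this commit), assuming the built-in ``TQDMProgressBar`` callback:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import TQDMProgressBar

    # keep the progress bar enabled, but refresh it less often (every 50 batches)
    trainer = Trainer(enable_progress_bar=True, callbacks=[TQDMProgressBar(refresh_rate=50)])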

@@ -871,18 +934,6 @@ See the :doc:`profiler documentation <../tuning/profiler>` for more details.
     # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
     trainer = Trainer(profiler="advanced")

-enable_progress_bar
-^^^^^^^^^^^^^^^^^^^
-
-Whether to enable or disable the progress bar. Defaults to True.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(enable_progress_bar=True)
-
-    # disable progress bar
-    trainer = Trainer(enable_progress_bar=False)

 reload_dataloaders_every_n_epochs
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -917,28 +968,6 @@ The pseudocode applies also to the ``val_dataloader``.

 .. _replace-sampler-ddp:

-use_distributed_sampler
-^^^^^^^^^^^^^^^^^^^^^^^
-
-See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(use_distributed_sampler=True)
-
-By setting to False, you have to add your own distributed sampler:
-
-.. code-block:: python
-
-    # in your LightningModule or LightningDataModule
-    def train_dataloader(self):
-        dataset = ...
-        # default used by the Trainer
-        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
-        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
-        return dataloader
-

 strategy
 ^^^^^^^^
@@ -982,6 +1011,29 @@ Enable synchronization between batchnorm layers across all GPUs.
     trainer = Trainer(sync_batchnorm=True)


+use_distributed_sampler
+^^^^^^^^^^^^^^^^^^^^^^^
+
+See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(use_distributed_sampler=True)
+
+By setting to False, you have to add your own distributed sampler:
+
+.. code-block:: python
+
+    # in your LightningModule or LightningDataModule
+    def train_dataloader(self):
+        dataset = ...
+        # default used by the Trainer
+        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
+        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
+        return dataloader
+
+
 val_check_interval
 ^^^^^^^^^^^^^^^^^^
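The relocated ``use_distributed_sampler`` example covers ``train_dataloader`` only; as a complementary sketch (not part of this commit), an evaluation dataloader would typically pair it with a non-shuffling sampler built the same way:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader


    # in your LightningModule or LightningDataModule (hypothetical counterpart to the snippet above)
    def val_dataloader(self):
        dataset = ...
        # no shuffling for evaluation; drop_last=False keeps every sample
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False, drop_last=False)
        return DataLoader(dataset, batch_size=32, sampler=sampler)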

@@ -1059,25 +1111,6 @@ Can specify as float, int, or a time-based duration.
     total_fit_batches = total_train_batches + total_val_batches


-enable_model_summary
-^^^^^^^^^^^^^^^^^^^^
-
-Whether to enable or disable the model summarization. Defaults to True.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(enable_model_summary=True)
-
-    # disable summarization
-    trainer = Trainer(enable_model_summary=False)
-
-    # enable custom summarization
-    from lightning.pytorch.callbacks import ModelSummary
-
-    trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
-
-
 inference_mode
 ^^^^^^^^^^^^^^

@@ -1109,33 +1142,6 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your mode
     trainer = Trainer(inference_mode=False)
     trainer.validate(model)

-enable_autolog_hparams
-^^^^^^^^^^^^^^^^^^^^^^
-
-Whether to log hyperparameters at the start of a run. Defaults to True.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(enable_autolog_hparams=True)
-
-    # disable logging hyperparams
-    trainer = Trainer(enable_autolog_hparams=False)
-
-With the parameter set to false, you can add custom code to log hyperparameters.
-
-.. code-block:: python
-
-    model = LitModel()
-    trainer = Trainer(enable_autolog_hparams=False)
-    for logger in trainer.loggers:
-        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
-            logger.log_hyperparams(hparams_dict_1)
-        else:
-            logger.log_hyperparams(hparams_dict_2)
-
-You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log.
-
 -----

 Trainer class API
