@@ -413,6 +413,35 @@ Number of devices to train on (``int``), which devices to train on (``list`` or
    # Training with GPU Accelerator using total number of gpus available on the system
    Trainer(accelerator="gpu")

+
+ enable_autolog_hparams
+ ^^^^^^^^^^^^^^^^^^^^^^
+
+ Whether to log hyperparameters at the start of a run. Defaults to True.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(enable_autolog_hparams=True)
+
+     # disable logging hyperparams
+     trainer = Trainer(enable_autolog_hparams=False)
+
+ With the parameter set to ``False``, you can add custom code to log hyperparameters.
+
+ .. code-block:: python
+
+     model = LitModel()
+     trainer = Trainer(enable_autolog_hparams=False)
+     for logger in trainer.loggers:
+         if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+             logger.log_hyperparams(hparams_dict_1)
+         else:
+             logger.log_hyperparams(hparams_dict_2)
+
+ You can also use ``self.logger.log_hyperparams(...)`` inside ``LightningModule`` to log.
+
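+ For example, a minimal sketch of logging from inside the module; the ``on_fit_start`` hook and the hyperparameter values are illustrative choices, not requirements of this flag:
+
+ .. code-block:: python
+
+     import lightning.pytorch as pl
+
+
+     class LitModel(pl.LightningModule):
+         def on_fit_start(self):
+             # log a custom dict through whichever logger is attached (illustrative keys)
+             if self.logger is not None:
+                 self.logger.log_hyperparams({"learning_rate": 0.01, "batch_size": 32})
+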
+
enable_checkpointing
^^^^^^^^^^^^^^^^^^^^
@@ -443,6 +472,40 @@ See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to c
    # Add your callback to the callbacks list
    trainer = Trainer(callbacks=[checkpoint_callback])

+
+ enable_model_summary
+ ^^^^^^^^^^^^^^^^^^^^
+
+ Whether to enable or disable the model summarization. Defaults to True.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(enable_model_summary=True)
+
+     # disable summarization
+     trainer = Trainer(enable_model_summary=False)
+
+     # enable custom summarization
+     from lightning.pytorch.callbacks import ModelSummary
+
+     trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
+
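+ If the optional ``rich`` package is installed, the ``RichModelSummary`` callback (a drop-in variant of ``ModelSummary``) can be passed the same way; a brief sketch:
+
+ .. code-block:: python
+
+     from lightning.pytorch import Trainer
+     from lightning.pytorch.callbacks import RichModelSummary
+
+     # same idea as above, but rendered with rich formatting (requires `rich`)
+     trainer = Trainer(callbacks=[RichModelSummary(max_depth=2)])
+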
+
+ enable_progress_bar
+ ^^^^^^^^^^^^^^^^^^^
+
+ Whether to enable or disable the progress bar. Defaults to True.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(enable_progress_bar=True)
+
+     # disable progress bar
+     trainer = Trainer(enable_progress_bar=False)
+
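+ To customize the bar rather than turn it off, a progress bar callback can be passed through ``callbacks`` and it replaces the default one. A minimal sketch, where the ``refresh_rate`` value of 10 is just an illustrative choice:
+
+ .. code-block:: python
+
+     from lightning.pytorch import Trainer
+     from lightning.pytorch.callbacks import TQDMProgressBar
+
+     # redraw the bar every 10 batches instead of every batch
+     trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
+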
+
fast_dev_run
^^^^^^^^^^^^
@@ -871,18 +934,6 @@ See the :doc:`profiler documentation <../tuning/profiler>` for more details.
    # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
    trainer = Trainer(profiler="advanced")

- enable_progress_bar
- ^^^^^^^^^^^^^^^^^^^
-
- Whether to enable or disable the progress bar. Defaults to True.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(enable_progress_bar=True)
-
-     # disable progress bar
-     trainer = Trainer(enable_progress_bar=False)

reload_dataloaders_every_n_epochs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -917,28 +968,6 @@ The pseudocode applies also to the ``val_dataloader``.
.. _replace-sampler-ddp:

- use_distributed_sampler
- ^^^^^^^^^^^^^^^^^^^^^^^
-
- See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(use_distributed_sampler=True)
-
- By setting to False, you have to add your own distributed sampler:
-
- .. code-block:: python
-
-     # in your LightningModule or LightningDataModule
-     def train_dataloader(self):
-         dataset = ...
-         # default used by the Trainer
-         sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
-         dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
-         return dataloader
-

strategy
^^^^^^^^
@@ -982,6 +1011,29 @@ Enable synchronization between batchnorm layers across all GPUs.
    trainer = Trainer(sync_batchnorm=True)

+
+ use_distributed_sampler
+ ^^^^^^^^^^^^^^^^^^^^^^^
+
+ See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.
+
+ .. testcode::
+
+     # default used by the Trainer
+     trainer = Trainer(use_distributed_sampler=True)
+
+ If you set this to ``False``, you have to add your own distributed sampler:
+
+ .. code-block:: python
+
+     # in your LightningModule or LightningDataModule
+     def train_dataloader(self):
+         dataset = ...
+         # default used by the Trainer
+         sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
+         dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
+         return dataloader
+
+
val_check_interval
^^^^^^^^^^^^^^^^^^
@@ -1059,25 +1111,6 @@ Can specify as float, int, or a time-based duration.
    total_fit_batches = total_train_batches + total_val_batches

- enable_model_summary
- ^^^^^^^^^^^^^^^^^^^^
-
- Whether to enable or disable the model summarization. Defaults to True.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(enable_model_summary=True)
-
-     # disable summarization
-     trainer = Trainer(enable_model_summary=False)
-
-     # enable custom summarization
-     from lightning.pytorch.callbacks import ModelSummary
-
-     trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
-
-
inference_mode
^^^^^^^^^^^^^^
@@ -1109,33 +1142,6 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your mode
    trainer = Trainer(inference_mode=False)
    trainer.validate(model)

- enable_autolog_hparams
- ^^^^^^^^^^^^^^^^^^^^^^
-
- Whether to log hyperparameters at the start of a run. Defaults to True.
-
- .. testcode::
-
-     # default used by the Trainer
-     trainer = Trainer(enable_autolog_hparams=True)
-
-     # disable logging hyperparams
-     trainer = Trainer(enable_autolog_hparams=False)
-
- With the parameter set to false, you can add custom code to log hyperparameters.
-
- .. code-block:: python
-
-     model = LitModel()
-     trainer = Trainer(enable_autolog_hparams=False)
-     for logger in trainer.loggers:
-         if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
-             logger.log_hyperparams(hparams_dict_1)
-         else:
-             logger.log_hyperparams(hparams_dict_2)
-
- You can also use ``self.logger.log_hyperparams(...)`` inside ``LightningModule`` to log.
-

-----

Trainer class API