diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index c40ad1fcf92e2..339c59771001a 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -1077,6 +1077,32 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your mode
 
     trainer = Trainer(inference_mode=False)
     trainer.validate(model)
 
+enable_autolog_hparams
+^^^^^^^^^^^^^^^^^^^^^^
+
+Whether to log hyperparameters at the start of a run. Defaults to ``True``.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_autolog_hparams=True)
+
+    # disable logging hyperparams
+    trainer = Trainer(enable_autolog_hparams=False)
+
+With the parameter set to ``False``, you can add custom code to log hyperparameters, for example per logger type:
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = Trainer(enable_autolog_hparams=False)
+    for logger in trainer.loggers:
+        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+            logger.log_hyperparams(hparams_dict_1)
+        else:
+            logger.log_hyperparams(hparams_dict_2)
+
+You can also call ``self.logger.log_hyperparams(...)`` inside the ``LightningModule`` to log hyperparameters manually.
 
 
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index eae21264ad5c9..f300b95eea8d1 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add `enable_autolog_hparams` argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
+
+
 - Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
 
 
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 0509f28acb07a..8e119e24438db 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -128,6 +128,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
+        enable_autolog_hparams: bool = True,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -290,6 +291,9 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
 
+            enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
+                Default: ``True``.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -496,6 +500,8 @@ def __init__(
             num_sanity_val_steps,
         )
 
+        self.enable_autolog_hparams = enable_autolog_hparams
+
     def fit(
         self,
         model: "pl.LightningModule",
@@ -962,7 +968,9 @@ def _run(
         call._call_callback_hooks(self, "on_fit_start")
         call._call_lightning_module_hook(self, "on_fit_start")
 
-        _log_hyperparams(self)
+        # only log hparams if enabled
+        if self.enable_autolog_hparams:
+            _log_hyperparams(self)
 
         if self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
 
diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py
index c131d03d38245..3b1e4dc91e391 100644
--- a/tests/tests_pytorch/loggers/test_csv.py
+++ b/tests/tests_pytorch/loggers/test_csv.py
@@ -163,7 +163,7 @@ def test_metrics_reset_after_save(tmp_path):
 
 
 @mock.patch(
-    # Mock the existance check, so we can simulate appending to the metrics file
+    # Mock the existence check, so we can simulate appending to the metrics file
     "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
 )
 def test_append_metrics_file(_, tmp_path):
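Below is a minimal usage sketch, not part of the patch, showing how the new `enable_autolog_hparams` flag pairs with manual hyperparameter logging from inside a `LightningModule`. It assumes the flag from this diff is available; `LitModel`, its `learning_rate` hyperparameter, and the reuse of `BoringModel` from `lightning.pytorch.demos` are illustrative choices only.

```python
# Usage sketch only -- assumes the enable_autolog_hparams flag from this patch is merged.
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import CSVLogger


class LitModel(BoringModel):  # illustrative model, not part of the patch
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # Populates self.hparams; with autologging disabled, nothing is sent
        # to the loggers automatically when fit() starts.
        self.save_hyperparameters()

    def on_fit_start(self) -> None:
        # Log hyperparameters manually, e.g. enriched with extra metadata.
        self.logger.log_hyperparams({**self.hparams, "note": "logged manually"})


trainer = pl.Trainer(
    max_epochs=1,
    logger=CSVLogger("logs"),
    enable_autolog_hparams=False,  # skip the automatic _log_hyperparams call in Trainer._run
)
trainer.fit(LitModel())
```

This mirrors the per-logger example added to trainer.rst above, collapsed to a single logger and moved into the module's `on_fit_start` hook.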