From 6c3c38dc2932861726183de456483130e959f90f Mon Sep 17 00:00:00 2001
From: duydl
Date: Tue, 18 Feb 2025 16:32:11 +0700
Subject: [PATCH 1/6] Make hyperparam logging optional

---
 src/lightning/pytorch/trainer/trainer.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 0509f28acb07a..9edafa583709a 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -128,6 +128,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
+        log_hyperparams_enabled: bool = True,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -290,6 +291,9 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
 
+            log_hyperparams_enabled: Whether to log hyperparameters at the start of a run.
+                Default: ``True``.
+
         Raises:
             TypeError:
                 If ``gradient_clip_val`` is not an int or float.
@@ -496,6 +500,8 @@ def __init__(
             num_sanity_val_steps,
         )
 
+        self.log_hyperparams_enabled = log_hyperparams_enabled
+
     def fit(
         self,
         model: "pl.LightningModule",
@@ -962,7 +968,9 @@ def _run(
         call._call_callback_hooks(self, "on_fit_start")
         call._call_lightning_module_hook(self, "on_fit_start")
 
-        _log_hyperparams(self)
+        # only log hparams if enabled
+        if self.log_hyperparams_enabled:
+            _log_hyperparams(self)
 
         if self.strategy.restore_checkpoint_after_setup:
             log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")

From ba576bf6b87af27ab1a8453f11940fae491c0b4c Mon Sep 17 00:00:00 2001
From: duydl
Date: Tue, 18 Feb 2025 19:06:07 +0700
Subject: [PATCH 2/6] Add docs

---
 docs/source-pytorch/common/trainer.rst   | 24 ++++++++++++++++++++++++
 src/lightning/pytorch/trainer/trainer.py |  8 ++++----
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index c40ad1fcf92e2..2e7e1eee0b14e 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -1077,6 +1077,30 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your mode
     trainer = Trainer(inference_mode=False)
     trainer.validate(model)
 
+enable_autolog_hparams
+^^^^^^^^^^^^^^^^^^^^^^
+
+Whether to log hyperparameters at the start of a run. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_autolog_hparams=True)
+
+    # disable logging hyperparams
+    trainer = Trainer(enable_autolog_hparams=False)
+
+With the parameter set to false, you can add custom code to log hyperparameters.
+
+.. code-block:: python
+
+    model = LitModel()
+    trainer = Trainer(enable_autolog_hparams=False)
+    for logger in trainer.loggers:
+        if isinstance(logger, lightning.pytorch.loggers.CSVLogger):
+            logger.log_hyperparams(hparams_dict_1)
+        else:
+            logger.log_hyperparams(hparams_dict_2)
 
 -----
 
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 9edafa583709a..8e119e24438db 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -128,7 +128,7 @@ def __init__(
         sync_batchnorm: bool = False,
         reload_dataloaders_every_n_epochs: int = 0,
         default_root_dir: Optional[_PATH] = None,
-        log_hyperparams_enabled: bool = True,
+        enable_autolog_hparams: bool = True,
     ) -> None:
         r"""Customize every aspect of training via flags.
 
@@ -291,7 +291,7 @@ def __init__(
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
 
-            log_hyperparams_enabled: Whether to log hyperparameters at the start of a run.
+            enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
                 Default: ``True``.
 
         Raises:
@@ -500,7 +500,7 @@ def __init__(
             num_sanity_val_steps,
         )
 
-        self.log_hyperparams_enabled = log_hyperparams_enabled
+        self.enable_autolog_hparams = enable_autolog_hparams
 
     def fit(
         self,
@@ -969,7 +969,7 @@ def _run(
         call._call_lightning_module_hook(self, "on_fit_start")
 
         # only log hparams if enabled
-        if self.log_hyperparams_enabled:
+        if self.enable_autolog_hparams:
             _log_hyperparams(self)
 
         if self.strategy.restore_checkpoint_after_setup:

From 43d9d2c5137d08fce55298e166e6bc27f374baae Mon Sep 17 00:00:00 2001
From: duydl
Date: Tue, 18 Feb 2025 19:17:11 +0700
Subject: [PATCH 3/6] Modify docs

---
 docs/source-pytorch/common/trainer.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 2e7e1eee0b14e..fc13ff703b5e2 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -1102,6 +1102,8 @@ With the parameter set to false, you can add custom code to log hyperparameters.
         else:
             logger.log_hyperparams(hparams_dict_2)
 
+You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log. 
+
 -----
 
 Trainer class API

From 8170cea0f3b0352d1e5b0aa4c9d76dd241436394 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 18 Feb 2025 12:17:35 +0000
Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/source-pytorch/common/trainer.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index fc13ff703b5e2..339c59771001a 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -1102,7 +1102,7 @@ With the parameter set to false, you can add custom code to log hyperparameters.
         else:
             logger.log_hyperparams(hparams_dict_2)
 
-You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log. 
+You can also use `self.logger.log_hyperparams(...)` inside `LightningModule` to log.
 
 -----
 

From 0da2efaaf2d9ac3562b0611af3723ae8b4edf198 Mon Sep 17 00:00:00 2001
From: duydl
Date: Wed, 19 Feb 2025 18:46:15 +0700
Subject: [PATCH 5/6] Add to CHANGELOG.md

---
 src/lightning/pytorch/CHANGELOG.md | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index ef0f3dc73c9e0..399e0de27749d 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+- Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
+
+### Changed
+
+### Removed
+
+### Fixed
+
 ## [2.5.0] - 2024-12-19
 
 ### Added

From 0f1c4689ba6b2b80612680fecb94af1f559d0e98 Mon Sep 17 00:00:00 2001
From: duydl
Date: Wed, 19 Feb 2025 18:46:22 +0700
Subject: [PATCH 6/6] Fix typos

---
 tests/tests_pytorch/loggers/test_csv.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py
index 1b09302ffb74a..a0901057cd526 100644
--- a/tests/tests_pytorch/loggers/test_csv.py
+++ b/tests/tests_pytorch/loggers/test_csv.py
@@ -165,7 +165,7 @@ def test_metrics_reset_after_save(tmp_path):
 
 
 @mock.patch(
-    # Mock the existance check, so we can simulate appending to the metrics file
+    # Mock the existence check, so we can simulate appending to the metrics file
     "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
 )
 def test_append_metrics_file(_, tmp_path):
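
For reference, a minimal sketch of the behavior these commits introduce: with ``enable_autolog_hparams=False`` the trainer skips the automatic ``_log_hyperparams`` call at the start of ``fit``. The ``BoringModel``, the mocked ``log_hyperparams`` method, the single-default-logger assumption, and the asserts below are illustrative only and are not code from this series.

.. code-block:: python

    from unittest import mock

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.pytorch import LightningModule, Trainer


    class BoringModel(LightningModule):
        """Tiny illustrative module; any LightningModule that calls save_hyperparameters() works."""

        def __init__(self, lr: float = 0.01):
            super().__init__()
            self.save_hyperparameters()  # fills hparams_initial, which autologging would record
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.hparams.lr)


    def hparams_were_logged(enable_autolog_hparams: bool) -> bool:
        """Run one short fit and report whether the logger received hyperparameters."""
        data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
        trainer = Trainer(
            max_epochs=1,
            enable_progress_bar=False,
            enable_checkpointing=False,
            enable_autolog_hparams=enable_autolog_hparams,
        )
        # Spy on the default logger so the call is observable without reading log files.
        with mock.patch.object(trainer.logger, "log_hyperparams") as spy:
            trainer.fit(BoringModel(), data)
        return spy.called


    assert hparams_were_logged(enable_autolog_hparams=True)
    assert not hparams_were_logged(enable_autolog_hparams=False)

When the flag is off, hyperparameters can still be recorded manually, for example via ``self.logger.log_hyperparams(self.hparams)`` in ``on_fit_start``, as the docs added in this series note.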