diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index ef0f3dc73c9e0..8bc8e45989f77 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - YYYY-MM-DD + +### Added + +### Changed + +### Removed + +### Fixed + +- Fixed CSVLogger logging hyperparameters at every write, which increases latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594)) + + ## [2.5.0] - 2024-12-19 ### Added diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 8606264dc3cdb..5ad7353310af4 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None: self.hparams: dict[str, Any] = {} def log_hparams(self, params: dict[str, Any]) -> None: - """Record hparams.""" + """Record hparams and save into files.""" self.hparams.update(params) - - @override - def save(self) -> None: - """Save recorded hparams and metrics into files.""" hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) save_hparams_to_yaml(hparams_file, self.hparams) - return super().save() class CSVLogger(Logger, FabricCSVLogger): @@ -144,7 +139,7 @@ def save_dir(self) -> str: @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 1b09302ffb74a..c131d03d38245 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -75,7 +75,6 @@ def
test_named_version(tmp_path): logger = CSVLogger(save_dir=tmp_path, name=exp_name, version=expected_version) logger.log_hyperparams({"a": 1, "b": 2}) - logger.save() assert logger.version == expected_version assert os.listdir(tmp_path / exp_name) == [expected_version] assert os.listdir(tmp_path / exp_name / expected_version) @@ -85,7 +84,7 @@ def test_named_version(tmp_path): def test_no_name(tmp_path, name): """Verify that None or empty name works.""" logger = CSVLogger(save_dir=tmp_path, name=name) - logger.save() + logger.log_hyperparams() assert os.path.normpath(logger.root_dir) == str(tmp_path) # use os.path.normpath to handle trailing / assert os.listdir(tmp_path / "version_0") @@ -116,7 +115,6 @@ def test_log_hyperparams(tmp_path): "layer": torch.nn.BatchNorm1d, } logger.log_hyperparams(hparams) - logger.save() path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE) params = load_hparams_from_yaml(path_yaml)