Skip to content
13 changes: 13 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [unreleased] - YYYY-MM-DD

### Added

### Changed

### Removed

### Fixed

- Fixed `CSVLogger` logging hyperparameters at every write, which increased latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))


## [2.5.0] - 2024-12-19

### Added
Expand Down
9 changes: 2 additions & 7 deletions src/lightning/pytorch/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None:
self.hparams: dict[str, Any] = {}

def log_hparams(self, params: dict[str, Any]) -> None:
"""Record hparams."""
"""Record hparams and save into files."""
self.hparams.update(params)

@override
def save(self) -> None:
"""Save recorded hparams and metrics into files."""
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
save_hparams_to_yaml(hparams_file, self.hparams)
return super().save()


class CSVLogger(Logger, FabricCSVLogger):
Expand Down Expand Up @@ -144,7 +139,7 @@ def save_dir(self) -> str:

@override
@rank_zero_only
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None:
params = _convert_params(params)
self.experiment.log_hparams(params)

Expand Down
1 change: 0 additions & 1 deletion tests/tests_pytorch/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def test_log_hyperparams(tmp_path):
"layer": torch.nn.BatchNorm1d,
}
logger.log_hyperparams(hparams)
logger.save()

path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
params = load_hparams_from_yaml(path_yaml)
Expand Down
Loading