
Commit 72f4706

Merge branch 'master' into feature/option-disable-loghparams

2 parents: c0256a4 + 1f5add3

File tree: 5 files changed (+11 −14 lines)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
+
+
 ## [2.5.0] - 2024-12-19
 
 ### Added
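
For context on the entry above: before this change, CSVLogger re-serialized hparams.yaml on every metrics flush, so frequent logging paid a YAML-dump cost each time. A minimal sketch of the public API involved (the loop and values are illustrative, not from the PR):

```python
from lightning.pytorch.loggers import CSVLogger

logger = CSVLogger(save_dir="logs", name="demo")
logger.log_hyperparams({"lr": 1e-3, "batch_size": 32})  # hparams.yaml written here

# Previously, each flush of metrics.csv also rewrote hparams.yaml;
# after this fix, metric writes no longer touch the YAML file.
for step in range(100):
    logger.log_metrics({"loss": 1.0 / (step + 1)}, step=step)
logger.save()  # flushes buffered metrics to metrics.csv
```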

src/lightning/pytorch/loggers/csv_logs.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None:
5555
self.hparams: dict[str, Any] = {}
5656

5757
def log_hparams(self, params: dict[str, Any]) -> None:
58-
"""Record hparams."""
58+
"""Record hparams and save into files."""
5959
self.hparams.update(params)
60-
61-
@override
62-
def save(self) -> None:
63-
"""Save recorded hparams and metrics into files."""
6460
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
6561
save_hparams_to_yaml(hparams_file, self.hparams)
66-
return super().save()
6762

6863

6964
class CSVLogger(Logger, FabricCSVLogger):
@@ -144,7 +139,7 @@ def save_dir(self) -> str:
144139

145140
@override
146141
@rank_zero_only
147-
def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
142+
def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None:
148143
params = _convert_params(params)
149144
self.experiment.log_hparams(params)
150145

tests/tests_fabric/utilities/test_seed.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 import os
 import random
+import warnings
 from unittest import mock
 from unittest.mock import Mock
 
@@ -30,9 +31,9 @@ def test_seed_stays_same_with_multiple_seed_everything_calls():
     seed_everything()
     initial_seed = os.environ.get("PL_GLOBAL_SEED")
 
-    with pytest.warns(None) as record:
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
         seed_everything()
-    assert not record  # does not warn
     seed = os.environ.get("PL_GLOBAL_SEED")
 
     assert initial_seed == seed
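
This test change replaces `pytest.warns(None)`, which newer pytest versions no longer accept as a "no warnings" check, with the standard-library equivalent: escalate warnings to errors inside the block so any warning fails the test. The general pattern, shown with a hypothetical helper name:

```python
import warnings


def assert_no_warnings(fn, *args, **kwargs):
    """Hypothetical helper: fail loudly if fn emits any warning."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")  # every warning now raises as an exception
        return fn(*args, **kwargs)
```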

tests/tests_pytorch/callbacks/test_lr_monitor.py

Lines changed: 2 additions & 2 deletions
@@ -548,10 +548,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer):
             """Called when the epoch begins."""
             if epoch == 1 and isinstance(optimizer, torch.optim.SGD):
                 self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
-            if epoch == 2 and isinstance(optimizer, torch.optim.Adam):
+            if epoch == 2 and type(optimizer) is torch.optim.Adam:
                 self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)
 
-            if epoch == 3 and isinstance(optimizer, torch.optim.Adam):
+            if epoch == 3 and type(optimizer) is torch.optim.Adam:
                 assert len(optimizer.param_groups) == 2
                 self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
                 assert len(optimizer.param_groups) == 3
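
The switch from `isinstance` to an exact-type check matters because `isinstance` also matches subclasses; on recent PyTorch releases `torch.optim.AdamW` subclasses `torch.optim.Adam`, so the old check would fire for both optimizers. A small demonstration (the subclass relationship is version-dependent):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
adamw = torch.optim.AdamW(params)

# On PyTorch versions where AdamW subclasses Adam, isinstance matches both:
print(isinstance(adamw, torch.optim.Adam))  # True there; version-dependent
print(type(adamw) is torch.optim.Adam)      # False everywhere: exact-type check
```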

tests/tests_pytorch/loggers/test_csv.py

Lines changed: 1 addition & 3 deletions
@@ -75,7 +75,6 @@ def test_named_version(tmp_path):
 
     logger = CSVLogger(save_dir=tmp_path, name=exp_name, version=expected_version)
     logger.log_hyperparams({"a": 1, "b": 2})
-    logger.save()
     assert logger.version == expected_version
     assert os.listdir(tmp_path / exp_name) == [expected_version]
     assert os.listdir(tmp_path / exp_name / expected_version)
@@ -85,7 +84,7 @@
 def test_no_name(tmp_path, name):
     """Verify that None or empty name works."""
     logger = CSVLogger(save_dir=tmp_path, name=name)
-    logger.save()
+    logger.log_hyperparams()
     assert os.path.normpath(logger.root_dir) == str(tmp_path)  # use os.path.normpath to handle trailing /
     assert os.listdir(tmp_path / "version_0")
 
@@ -116,7 +115,6 @@ def test_log_hyperparams(tmp_path):
         "layer": torch.nn.BatchNorm1d,
     }
     logger.log_hyperparams(hparams)
-    logger.save()
 
     path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
     params = load_hparams_from_yaml(path_yaml)
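
The test updates follow directly from the logger change: `log_hyperparams()` now writes hparams.yaml immediately (and, with the new `Optional` default, can be called with no arguments), so the trailing `logger.save()` calls became redundant. A minimal sketch of what the updated tests rely on:

```python
import os

from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.loggers.csv_logs import ExperimentWriter

logger = CSVLogger(save_dir="logs", name="demo")
logger.log_hyperparams()  # allowed now that params defaults to None

# The YAML file exists without any explicit save() call.
print(os.path.exists(os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)))
```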
