Commit 5fa32d9

Ignore parameters causing ValueError when dumping to YAML (#19804)
1 parent 4f96c83 commit 5fa32d9

File tree

3 files changed: +12 -2 lines changed

- src/lightning/pytorch/CHANGELOG.md
- src/lightning/pytorch/core/saving.py
- tests/tests_pytorch/models/test_hparams.py

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))
 
+- Fixed an issue causing ValueError for certain object such as TorchMetrics when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804))
+
 
 ## [2.2.2] - 2024-04-11
 
src/lightning/pytorch/core/saving.py

Lines changed: 1 addition & 1 deletion
@@ -359,7 +359,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us
         try:
             v = v.name if isinstance(v, Enum) else v
             yaml.dump(v)
-        except TypeError:
+        except (TypeError, ValueError):
             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
             hparams[k] = type(v).__name__
         else:
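
For context, a minimal standalone sketch of the skip-on-failure pattern this patch broadens. It is not the library source; UnserializableThing is a hypothetical stand-in for values (such as TorchMetrics objects) whose YAML representation fails with ValueError rather than TypeError:

import warnings
import yaml


class UnserializableThing:
    """Hypothetical stand-in for a value whose YAML representation raises ValueError."""

    def __reduce__(self):
        raise ValueError("cannot be reduced for serialization")


def dump_hparams_safely(hparams: dict) -> str:
    """Dump hparams to YAML, replacing values that cannot be dumped with their type name."""
    cleaned = {}
    for k, v in hparams.items():
        try:
            yaml.dump(v)  # probe whether this value can be dumped on its own
        except (TypeError, ValueError):  # the commit adds ValueError to this clause
            warnings.warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
            cleaned[k] = type(v).__name__
        else:
            cleaned[k] = v
    return yaml.dump(cleaned)


print(dump_hparams_safely({"lr": 0.1, "metric": UnserializableThing()}))
# lr: 0.1
# metric: UnserializableThing

Probing each value with yaml.dump(v) before writing the full file lets the dump degrade gracefully: one bad hyperparameter triggers a warning and is recorded by type name instead of aborting the whole checkpointing/logging step.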

tests/tests_pytorch/models/test_hparams.py

Lines changed: 9 additions & 1 deletion
@@ -552,7 +552,7 @@ def test_hparams_pickle_warning(tmp_path):
     trainer.fit(model)
 
 
-def test_hparams_save_yaml(tmp_path):
+def test_save_hparams_to_yaml(tmp_path):
     class Options(str, Enum):
         option1name = "option1val"
         option2name = "option2val"
@@ -590,6 +590,14 @@ def _compare_params(loaded_params, default_params: dict):
     _compare_params(load_hparams_from_yaml(path_yaml), hparams)
 
 
+def test_save_hparams_to_yaml_warning(tmp_path):
+    """Test that we warn about unserializable parameters that need to be dropped."""
+    path_yaml = tmp_path / "hparams.yaml"
+    hparams = {"torch_type": torch.float32}
+    with pytest.warns(UserWarning, match="Skipping 'torch_type' parameter"):
+        save_hparams_to_yaml(path_yaml, hparams)
+
+
 class NoArgsSubClassBoringModel(CustomBoringModel):
     def __init__(self):
         super().__init__()
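
For reference, a hedged sketch of the user-visible behavior the new test exercises. The file path and printed result are illustrative assumptions; the exact loaded mapping depends on how load_hparams_from_yaml returns its result:

import torch

from lightning.pytorch.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

# torch.float32 cannot be safely dumped to YAML. With this fix the call emits a
# UserWarning ("Skipping 'torch_type' parameter ...") instead of raising ValueError,
# and the offending value is replaced by its type name in the written file.
save_hparams_to_yaml("hparams.yaml", {"torch_type": torch.float32, "lr": 0.1})
print(load_hparams_from_yaml("hparams.yaml"))  # e.g. {'torch_type': 'dtype', 'lr': 0.1}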
