diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 01bbc0ff03fb0..a2e9ded0aeb2c 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960)) +- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051)) + + --- ## [2.5.2] - 2025-06-20 diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 16eef555291bd..829cc7a994b93 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -167,7 +167,8 @@ def save_hyperparameters( if given_hparams is not None: init_args = given_hparams elif is_dataclass(obj): - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} + obj_fields = fields(obj) + init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init} else: init_args = {} diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index f14d62b6befb4..575bcadadc404 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -17,7 +17,7 @@ import pickle import sys from argparse import Namespace -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from unittest import mock @@ -881,6 +881,31 @@ def test_dataclass_lightning_module(tmp_path): assert model.hparams == {"mandatory": 33, "optional": "cocofruit"} +def test_dataclass_with_init_false_fields(): + """Test that save_hyperparameters() filters out fields with init=False and issues a warning.""" + + @dataclass + class DataClassWithInitFalseFieldsModel(BoringModel): + mandatory: int + optional: str = "optional" + non_init_field: int = field(default=999, init=False) + another_non_init: str = field(default="not_in_init", init=False) + + def __post_init__(self): + super().__init__() + self.save_hyperparameters() + + model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit") + + expected_hparams = {"mandatory": 33, "optional": "cocofruit"} + assert model.hparams == expected_hparams + + assert model.non_init_field == 999 + assert model.another_non_init == "not_in_init" + assert "non_init_field" not in model.hparams + assert "another_non_init" not in model.hparams + + class NoHparamsModel(BoringModel): """Tests a model without hparams."""