Skip to content

Commit b7ec502

Browse files
Fix save_hyperparameters not crashing on dataclass with init=False (#21051)
* fix implementation * add testing * changelog --------- Co-authored-by: Quentin Soubeyran <[email protected]>
1 parent 9c48699 commit b7ec502

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4141
- Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))
4242

4343

44+
- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051))
45+
46+
4447
---
4548

4649
## [2.5.2] - 2025-06-20

src/lightning/pytorch/utilities/parsing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def save_hyperparameters(
167167
if given_hparams is not None:
168168
init_args = given_hparams
169169
elif is_dataclass(obj):
170-
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
170+
obj_fields = fields(obj)
171+
init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init}
171172
else:
172173
init_args = {}
173174

tests/tests_pytorch/models/test_hparams.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pickle
1818
import sys
1919
from argparse import Namespace
20-
from dataclasses import dataclass
20+
from dataclasses import dataclass, field
2121
from enum import Enum
2222
from unittest import mock
2323

@@ -881,6 +881,31 @@ def test_dataclass_lightning_module(tmp_path):
881881
assert model.hparams == {"mandatory": 33, "optional": "cocofruit"}
882882

883883

884+
def test_dataclass_with_init_false_fields():
885+
"""Test that save_hyperparameters() filters out fields with init=False and issues a warning."""
886+
887+
@dataclass
888+
class DataClassWithInitFalseFieldsModel(BoringModel):
889+
mandatory: int
890+
optional: str = "optional"
891+
non_init_field: int = field(default=999, init=False)
892+
another_non_init: str = field(default="not_in_init", init=False)
893+
894+
def __post_init__(self):
895+
super().__init__()
896+
self.save_hyperparameters()
897+
898+
model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit")
899+
900+
expected_hparams = {"mandatory": 33, "optional": "cocofruit"}
901+
assert model.hparams == expected_hparams
902+
903+
assert model.non_init_field == 999
904+
assert model.another_non_init == "not_in_init"
905+
assert "non_init_field" not in model.hparams
906+
assert "another_non_init" not in model.hparams
907+
908+
884909
class NoHparamsModel(BoringModel):
885910
"""Tests a model without hparams."""
886911

0 commit comments

Comments
 (0)