Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))


- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051))


---

## [2.5.2] - 2025-06-20
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/pytorch/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,14 @@ def save_hyperparameters(
if given_hparams is not None:
init_args = given_hparams
elif is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
obj_fields = fields(obj)
init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init}
if any(not f.init for f in obj_fields):
rank_zero_warn(
"Detected a dataclass with fields with `init=False`. This is not supported by `save_hyperparameters`"
" and will not save those fields in `self.hparams`. Consider removing `init=False` and just"
" re-initialize the attributes in the `__post_init__` method of the dataclass."
)
else:
init_args = {}

Expand Down
28 changes: 27 additions & 1 deletion tests/tests_pytorch/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pickle
import sys
from argparse import Namespace
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from unittest import mock

Expand Down Expand Up @@ -881,6 +881,32 @@ def test_dataclass_lightning_module(tmp_path):
assert model.hparams == {"mandatory": 33, "optional": "cocofruit"}


def test_dataclass_with_init_false_fields():
"""Test that save_hyperparameters() filters out fields with init=False and issues a warning."""

@dataclass
class DataClassWithInitFalseFieldsModel(BoringModel):
mandatory: int
optional: str = "optional"
non_init_field: int = field(default=999, init=False)
another_non_init: str = field(default="not_in_init", init=False)

def __post_init__(self):
super().__init__()
self.save_hyperparameters()

with pytest.warns(UserWarning, match="Detected a dataclass with fields with `init=False`"):
model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit")

expected_hparams = {"mandatory": 33, "optional": "cocofruit"}
assert model.hparams == expected_hparams

assert model.non_init_field == 999
assert model.another_non_init == "not_in_init"
assert "non_init_field" not in model.hparams
assert "another_non_init" not in model.hparams


class NoHparamsModel(BoringModel):
"""Tests a model without hparams."""

Expand Down
Loading