
LightningModule.save_hyperparameters crash on dataclass with non-init fields #21036

@QuentinSoubeyranAqemia


Bug description

Howdy!

I think I found a small bug: calling self.save_hyperparameters() on a @dataclasses.dataclass-decorated class that has fields with init=False crashes, because save_hyperparameters() treats the non-init fields as init arguments.

This is due to

    elif is_dataclass(obj):
        init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}

not filtering out fields with init=False when gathering init arguments. I believe this code should read:

    elif is_dataclass(obj):
        init_args = {f.name: getattr(obj, f.name) for f in fields(obj) if f.init}
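For reference, here is a minimal standalone demonstration (no Lightning involved; Config is a made-up example class) that dataclasses.fields() returns init=False fields too, and that filtering on f.init keeps exactly the true constructor arguments:

```python
import dataclasses

@dataclasses.dataclass
class Config:
    lr: float
    steps: int = dataclasses.field(init=False, default=0)

cfg = Config(lr=1e-3)

# fields() returns *all* fields, including init=False ones
all_names = [f.name for f in dataclasses.fields(cfg)]
# filtering on f.init keeps only true __init__ arguments
init_names = [f.name for f in dataclasses.fields(cfg) if f.init]

print(all_names)   # ['lr', 'steps']
print(init_names)  # ['lr']
```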

Using ignore, or explicitly passing the list of hparams to save, doesn't fix it, because the init_args dict above is built before the *args and ignore arguments of save_hyperparameters() are considered. All fields must already be initialized for save_hyperparameters() to succeed, which is not always possible (e.g. training timestamps, internals that are created later, etc.).

Not declaring those attributes as fields could be a workaround, but it means forgoing type-checking, because a bare class-level annotation is enough for @dataclasses.dataclass to pick the attribute up as a field.
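To illustrate the trade-off (Annotated and Unannotated are made-up example classes): a bare annotation always becomes a field, while an attribute assigned only in __post_init__ is invisible to fields() but also carries no type annotation:

```python
import dataclasses

@dataclasses.dataclass
class Annotated:
    x: int
    y: float = 0.0  # a bare annotation is enough: y becomes a field

@dataclasses.dataclass
class Unannotated:
    x: int

    def __post_init__(self):
        # a plain attribute, invisible to fields() -- but also untyped
        self.y = 0.0

annotated_fields = [f.name for f in dataclasses.fields(Annotated)]
unannotated_fields = [f.name for f in dataclasses.fields(Unannotated)]
print(annotated_fields)    # ['x', 'y']
print(unannotated_fields)  # ['x']
```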

Cheers!

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import dataclasses

import lightning.pytorch as L

@dataclasses.dataclass
class Module(L.LightningModule):
    param: float
    not_a_param: float = dataclasses.field(init=False)

    def __post_init__(self):
        self.save_hyperparameters()
        self.not_a_param = 0.0

Module(param=1e-3)

Error messages and logs

AttributeError                            Traceback (most recent call last)
Cell In[10], line 10
      7         self.save_hyperparameters()
      8         self.not_a_param = 0.0
---> 10 Module(param=1e-3)

File <string>:4, in __init__(self, param)

Cell In[10], line 7, in Module.__post_init__(self)
      6 def __post_init__(self):
----> 7     self.save_hyperparameters()
      8     self.not_a_param = 0.0

File ~/.cache/uv/archive-v0/MfX1dSMSoPsfklX_8DY25/lib/python3.12/site-packages/lightning/pytorch/core/mixins/hparams_mixin.py:131, in HyperparametersMixin.save_hyperparameters(self, ignore, frame, logger, *args)
    129     if current_frame:
    130         frame = current_frame.f_back
--> 131 save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)

File ~/.cache/uv/archive-v0/MfX1dSMSoPsfklX_8DY25/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:170, in save_hyperparameters(obj, ignore, frame, given_hparams, *args)
    168     init_args = given_hparams
    169 elif is_dataclass(obj):
--> 170     init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
    171 else:
    172     init_args = {}

File ~/.cache/uv/archive-v0/MfX1dSMSoPsfklX_8DY25/lib/python3.12/site-packages/torch/nn/modules/module.py:1729, in Module.__getattr__(self, name)
   1727     if name in modules:
   1728         return modules[name]
-> 1729 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'Module' object has no attribute 'not_a_param'

Environment


Temporary environment created to run the code above using:

uv run --python 3.12 --with lightning --with ipython ipython
- PyTorch Lightning Version: 2.5.1.post0
- PyTorch Version: 2.4.1+cu121
- Python version: 3.12
- OS: Linux
- CUDA/cuDNN version: 12.1
- GPU models and configuration: N/A
- How you installed Lightning: uv

More info
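A possible user-side stopgap (not a fix): give the init=False field a default, so the attribute is resolvable by the time save_hyperparameters() runs in __post_init__. The field would then still be wrongly recorded among the init arguments, but the AttributeError goes away. A pure-dataclasses sketch of why the lookup succeeds (WithDefault is a made-up example class):

```python
import dataclasses

@dataclasses.dataclass
class WithDefault:
    param: float
    # a default makes getattr() succeed (resolved via the class attribute)
    not_a_param: float = dataclasses.field(init=False, default=0.0)

    def __post_init__(self):
        # every field, init or not, is now resolvable here
        assert all(hasattr(self, f.name) for f in dataclasses.fields(self))

obj = WithDefault(param=1e-3)
print(obj.not_a_param)  # 0.0
```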

Metadata

    Labels

    bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
