Bug description
Howdy!
I think I found a small bug: when `self.save_hyperparameters()` is called on a `@dataclasses.dataclass`-decorated class that has fields with `init=False`, the call crashes because it tries to use the non-init fields as init fields.
This is due to lines 169 to 170 of `pytorch-lightning/src/lightning/pytorch/utilities/parsing.py` (at commit 791753b):

```python
elif is_dataclass(obj):
    init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
```

These lines do not filter out fields with `init=False` when gathering init arguments. I believe this code should read:

```python
elif is_dataclass(obj):
    init_args = {f.name: getattr(obj, f.name) for f in fields(obj) if f.init}
```
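A minimal, torch-free sketch of the proposed filter, assuming only the standard `dataclasses` module. The helper `collect_init_args` and the class `Config` are hypothetical names for illustration, not Lightning APIs. It shows that an unset `init=False` field makes the unfiltered comprehension raise `AttributeError` here as well (from ordinary attribute lookup rather than `torch.nn.Module.__getattr__`), while the `if f.init` filter skips the field.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any


def collect_init_args(obj: Any) -> dict[str, Any]:
    """Hypothetical helper applying the proposed `if f.init` filter.

    Only fields accepted by the generated __init__ are read, so an
    `init=False` field that has not been assigned yet is skipped.
    """
    if not is_dataclass(obj):
        return {}
    return {f.name: getattr(obj, f.name) for f in fields(obj) if f.init}


@dataclass
class Config:
    param: float
    not_a_param: float = field(init=False)  # not set by __init__, no default


cfg = Config(param=1e-3)
print(collect_init_args(cfg))  # {'param': 0.001}

# Without the filter, reading the unset field fails like in the report.
try:
    {f.name: getattr(cfg, f.name) for f in fields(cfg)}
except AttributeError as exc:
    print(f"{type(exc).__name__}: {exc}")
```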
Using `ignore`, or explicitly passing the list of hparams to save, doesn't fix it, because the `init_args` object above is built before the `*args` or `ignore` arguments of `save_hyperparameters()` are considered. The fields must be initialized for `save_hyperparameters()` to succeed, but this is not always possible (e.g. training timestamps, internals that are created later, etc.).
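The ordering problem can be modeled with a deliberately simplified sketch. `save_hparams_sketch` and `Example` are hypothetical and are not Lightning's implementation; the function only mimics the behavior described above, where every dataclass field is read before `ignore` is consulted, so listing the field in `ignore` cannot prevent the `AttributeError`.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Optional, Sequence


def save_hparams_sketch(obj: Any, ignore: Optional[Sequence[str]] = None) -> dict[str, Any]:
    """Hypothetical, heavily simplified model of the ordering described above.

    Not Lightning's code: it only reproduces the order of operations in which
    all dataclass fields are read first and `ignore` is applied afterwards.
    """
    # Step 1: read every field value (no `f.init` filter); this can raise.
    init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
    # Step 2: `ignore` is applied only now, too late to prevent step 1.
    skipped = set(ignore or ())
    return {k: v for k, v in init_args.items() if k not in skipped}


@dataclass
class Example:
    param: float
    not_a_param: float = field(init=False)


example = Example(param=1e-3)
try:
    save_hparams_sketch(example, ignore=["not_a_param"])
    failed = False
except AttributeError as exc:
    failed = True
    print("ignore did not help:", exc)
```

Once the field has been assigned, the same call succeeds and `ignore` drops it, which matches the report's point that the fields must be initialized first.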
Not listing those attributes as fields could be a workaround, but it means forgoing type checking: the annotation alone is enough for `@dataclasses.dataclass` to pick up the attribute as a field, so keeping it out of the fields means dropping the annotation.
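To illustrate this trade-off with the standard library alone (the class names `Annotated` and `Unannotated` are hypothetical): an annotated attribute becomes a dataclass field even with `init=False`, while an attribute that is only assigned in `__post_init__` stays out of `fields()` but has no class-level type declaration.

```python
from dataclasses import dataclass, field, fields


@dataclass
class Annotated:
    param: float
    # Annotated, so it is a dataclass field even though init=False.
    not_a_param: float = field(init=False)


@dataclass
class Unannotated:
    param: float

    def __post_init__(self) -> None:
        # No class-level annotation: a plain instance attribute, not a field,
        # so fields() never reports it (and no type is declared for it).
        self.not_a_param = 0.0


print([f.name for f in fields(Annotated)])    # ['param', 'not_a_param']
print([f.name for f in fields(Unannotated)])  # ['param']
```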
Cheers!
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
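```python
import dataclasses

import lightning.pytorch as L


@dataclasses.dataclass
class Module(L.LightningModule):
    param: float
    not_a_param: float = dataclasses.field(init=False)

    def __post_init__(self):
        self.save_hyperparameters()  # raises AttributeError for `not_a_param`
        self.not_a_param = 0.0  # non-init field, only assigned afterwards


Module(param=1e-3)
```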
Error messages and logs
```
AttributeError Traceback (most recent call last)
Cell In[10], line 10
7 self.save_hyperparameters()
8 self.not_a_param = 0.0
---> 10 Module(param=1e-3)
File <string>:4, in __init__(self, param)
Cell In[10], line 7, in Module.__post_init__(self)
6 def __post_init__(self):
----> 7 self.save_hyperparameters()
8 self.not_a_param = 0.0
File ~/.cache/uv/archive-v0/MfX1dSMSoPsfklX_8DY25/lib/python3.12/site-packages/lightning/pytorch/core/mixins/hparams_mixin.py:131, in HyperparametersMixin.save_hyperparameters(self, ignore, frame, logger, *args)
129 if current_frame:
130 frame = current_frame.f_back
--> 131 save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)
File ~/.cache/uv/archive-v0/MfX1dSMSoPsfklX_8DY25/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:170, in save_hyperparameters(obj, ignore, frame, given_hparams, *args)
168 init_args = given_hparams
169 elif is_dataclass(obj):
--> 170 init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
171 else:
172 init_args = {}
File ~/.cache/uv/archive-v0/MfX1dSMSoPsfklX_8DY25/lib/python3.12/site-packages/torch/nn/modules/module.py:1729, in Module.__getattr__(self, name)
1727 if name in modules:
1728 return modules[name]
-> 1729 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Module' object has no attribute 'not_a_param'
```
Environment
Current environment
A temporary environment was created to run the code above using `uv run --python 3.12 --with lightning --with ipython ipython`.
#- PyTorch Lightning Version (e.g., 2.5.0): '2.5.1.post0'
#- PyTorch Version (e.g., 2.5): 2.4.1+cu121
#- Python version (e.g., 3.12): 3.12
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 12.1
#- GPU models and configuration: N/A
#- How you installed Lightning(`conda`, `pip`, source): `uv`