Bug description
I wrote some code in `load_state_dict` to synchronize variables from a third-party repository that are not registered as buffers. My code looks like this:
```python
# in a subclass of `torch.nn.Module`
def on_load_checkpoint(self, checkpoint):
    self.sync_buffer_to_badmodule()
    print(123)

def load_state_dict(self, state_dict, strict=True, assign=False):
    result = super().load_state_dict(state_dict, strict, assign)
    self.sync_buffer_to_badmodule()
    print(123)
    return result
```
More specifically, I use the k-means implementation from the following link inside my `torch.nn.Module`. That repository holds its clustering results in a `NamedTuple` attribute, `_result`, which is not registered as a buffer. I found that `_result` in the k-means module was not restored when I loaded from a Lightning checkpoint. So I registered a buffer on my own module to keep track of the k-means module's `_result`, and then restore that buffer back into the k-means module in `load_state_dict`. However, I found that `load_state_dict` is never called. No matter how hard I try, neither of these functions is executed when loading from a checkpoint. What am I supposed to do?
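For concreteness, here is a minimal sketch of the workaround described above. `KMeansWrapper`, the `centers` field, the buffer shape, and the body of `sync_buffer_to_badmodule` are assumptions based on the description, not the actual code:

```python
import torch
from torch import nn

class KMeansWrapper(nn.Module):
    """Sketch: mirror the k-means module's unregistered `_result` in a buffer."""

    def __init__(self, kmeans: nn.Module, num_clusters: int, dim: int):
        super().__init__()
        self.kmeans = kmeans
        # Checkpoint-visible mirror of `kmeans._result.centers` (shape assumed).
        self.register_buffer("centers_buf", torch.zeros(num_clusters, dim))

    def sync_buffer_to_badmodule(self):
        # Push the checkpointed buffer back into the third-party module's
        # plain (non-buffer) attribute after a load.
        if getattr(self.kmeans, "_result", None) is not None:
            # `_result` is a NamedTuple, so `_replace` returns an updated copy.
            self.kmeans._result = self.kmeans._result._replace(
                centers=self.centers_buf.clone()
            )

    def load_state_dict(self, state_dict, strict=True, assign=False):
        result = super().load_state_dict(state_dict, strict, assign)
        self.sync_buffer_to_badmodule()  # expected to run on checkpoint load
        return result
```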
If `load_state_dict` is not called, why are the model parameters and buffers loaded correctly? Does Lightning have a mechanism for loading the `state_dict` that is independent of `torch.nn.Module`?
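One detail that may be relevant, independent of Lightning: in plain PyTorch, `nn.Module.load_state_dict` restores submodules through the internal `_load_from_state_dict` recursion, so a `load_state_dict` override on a *child* module is bypassed when the parent loads. A minimal demonstration (module names are hypothetical):

```python
import torch
from torch import nn

class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(3))

    def load_state_dict(self, state_dict, strict=True, assign=False):
        print("Child.load_state_dict called")  # never printed below
        return super().load_state_dict(state_dict, strict, assign)

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

parent = Parent()
parent.load_state_dict({"child.buf": torch.ones(3)})
print(parent.child.buf)  # tensor([1., 1., 1.])
# The buffer IS restored, yet "Child.load_state_dict called" never prints:
# the parent restores each submodule via `_load_from_state_dict`, not via
# its `load_state_dict` override.
```

If that mechanism applies here, it would explain why the weights load correctly while the overrides never fire.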
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
```
# Error messages and logs here please
```
Environment
Current environment
```
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
```
More info
No response
cc @lantiga