Description & Motivation
My goal is to reduce the checkpoint file size. For example, saving a checkpoint of the SONAR model requires 6 GB of disk space.

My solution is to replace the checkpoint's state_dict with a dict containing only the trainable parameters (plus buffers):
```python
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # Replace the full state_dict with only the trainable subset
        checkpoint["state_dict"] = self.get_trainable_state_dict()

    def get_trainable_state_dict(self):
        state = {}
        # Keep only parameters that are actually being trained
        for name, param in self.named_parameters():
            if param.requires_grad:
                state[name] = param.detach().cpu()
        # Buffers (e.g. BatchNorm running stats) can change during
        # training, so keep them as well
        for name, buffer in self.named_buffers():
            state[name] = buffer.detach().cpu()
        return state
```
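For reference, a quick sanity check of what the hook actually saves (the checkpoint path here is hypothetical); Lightning checkpoints are plain torch-serialized dicts, so the saved keys can be inspected directly:

```python
import torch

# "partial.ckpt" is a hypothetical path produced by trainer.save_checkpoint;
# weights_only=False because a Lightning checkpoint stores more than tensors.
ckpt = torch.load("partial.ckpt", map_location="cpu", weights_only=False)

# Only the trainable parameters and the buffers should remain.
print(len(ckpt["state_dict"]), "entries kept")
```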
However, one thing I noticed is that I need strict=False in LitModel.load_from_checkpoint(..., strict=False) to load such a checkpoint. So I assume that resuming with trainer.fit(ckpt_path=...) would also fail without strict=False.
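For completeness, a minimal loading sketch (again with a hypothetical path); strict=False is forwarded to load_state_dict, so keys missing from the checkpoint are ignored rather than raising an error:

```python
# The frozen weights stripped by the hook are simply left at the values
# the model was constructed with, instead of triggering a missing-key error.
model = LitModel.load_from_checkpoint("partial.ckpt", strict=False)
```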
PS: I have not tried it myself; I am in the middle of training (3.8 of 13 hours) and don't want to risk it.
Pitch
No response
Alternatives
No response
Additional context
No response