Callback Checkpointing never called #13800
-
I am trying to save my Callback state along with my model at the end of each epoch. I put print statements/breakpoints in the `state_dict` and `load_state_dict` methods, but they are never called during training, and the `Counter` state is not recovered on resume. I tried checkpointing both with a `ModelCheckpoint` callback in the trainer's callback list and with `enable_checkpointing=True` (with and without the `ModelCheckpoint`), and neither works. For reference, other state is restored correctly (the right epoch, the model weights, etc.). Here is a minimal working example:
```python
import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import Parameter
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset


class Dummy_dataset(Dataset):
    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        return 1


class Dummy_Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Dummy parameter so the optimizer has something to hold on to.
        self.cpt = Parameter(torch.tensor(0.), requires_grad=False)

    def configure_optimizers(self):
        return Adam([
            {'params': self.cpt},
        ])

    def forward(self, data):
        return data

    def training_step(self, data, batch_idx):
        self.cpt += batch_idx
        print(f'\n{self.current_epoch=} {self.cpt=}')
        return None  # no loss returned: the optimization step is skipped, which is fine here

    def validation_step(self, data, batch_idx):
        pass

    def test_step(self, data, batch_idx):
        pass


class Counter(Callback):
    """Callback whose state should be saved to / restored from checkpoints."""

    def __init__(self):
        self.state = {"epochs": 0, "batches": 0}

    def on_train_epoch_end(self, *args, **kwargs):
        self.state['epochs'] += 1
        print(self.state)

    def on_train_batch_end(self, *args, **kwargs):
        self.state['batches'] += 1

    def load_state_dict(self, state_dict):
        print('load_state_dict')  # never printed
        self.state.update(state_dict)

    def state_dict(self):
        print('state_dict')  # never printed
        return self.state.copy()


train_dataloader = DataLoader(Dummy_dataset(length=10), batch_size=2)
valid_dataloader = DataLoader(Dummy_dataset(length=10), batch_size=2)

model = Dummy_Model()

checkpoint_callback = ModelCheckpoint(
    dirpath='./',
    monitor=None,
    verbose=True,
    save_last=True,
    every_n_epochs=1,
)

loading = False  # set to True for the second (resuming) run

trainer = Trainer(
    max_epochs=300 if loading else 50,
    callbacks=[Counter(), checkpoint_callback],
    enable_checkpointing=True,
)
trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
    ckpt_path='last.ckpt' if loading else None,
)
```
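For completeness, in case this turns out to be a version issue: on Lightning releases before 1.6, callback state was persisted through differently named hooks. Here is a minimal sketch of an equivalent `Counter` using that legacy API (the signatures below assume a 1.4/1.5-style install):

```python
from pytorch_lightning import Callback


class LegacyCounter(Callback):
    """Sketch of callback-state persistence on pre-1.6 Lightning (1.4/1.5 signatures)."""

    def __init__(self):
        self.state = {"epochs": 0, "batches": 0}

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Pre-1.6: the dict returned here is stored in the checkpoint
        # under this callback's state key.
        return self.state.copy()

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        # Pre-1.6: receives the dict that on_save_checkpoint returned.
        self.state.update(callback_state)
```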
Replies: 2 comments 4 replies
-
Your
-
What's your Lightning version?
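You can check with something like this; as far as I know, `Callback.state_dict()` / `load_state_dict()` were only added in v1.6.0, so on older releases the Trainer never calls them:

```python
import pytorch_lightning as pl

# Callback.state_dict / load_state_dict are only invoked from v1.6.0 onward,
# so the first thing to rule out is an older install.
print(pl.__version__)
```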