
Using trainer.fit(model, datamodule, ckpt_path=path) on compressed model #12987


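This isn't possible with `fit(..., ckpt_path=...)`. What you need is a partial resume, but passing `ckpt_path` restores the full training state, including the model weights, which would overwrite your compressed model. For your use case, you could try a callback that reloads only the optimizer state: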

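```python
import torch
from pytorch_lightning import Callback


class OptimizerReload(Callback):
    """Restore only the optimizer state from a checkpoint; model weights are left as-is."""

    def __init__(self, ckpt_path):
        self.ckpt_path = ckpt_path

    def on_train_start(self, trainer, pl_module):
        # By this hook the strategy has already created the optimizers, so their state can be replaced.
        ckpt = torch.load(self.ckpt_path, map_location="cpu")
        # Uses ckpt["optimizer_states"] only; ckpt["state_dict"] (the weights) is ignored.
        trainer.strategy.load_optimizer_state_dict(ckpt)
```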

Then pass it to the `Trainer`:

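```python
cb = OptimizerReload(ckpt_path)
trainer = Trainer(..., callbacks=[cb])
# No ckpt_path here, so fit() keeps the weights of the (compressed) model you pass in.
trainer.fit(model)
```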
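To see the mechanism outside Lightning, here is a minimal plain-PyTorch sketch (torch only). It writes a checkpoint dict that assumes Lightning's layout (weights under `"state_dict"`, a list of optimizer states under `"optimizer_states"`), then builds a fresh model and restores only the optimizer entries. `make_model` and the in-memory buffer are hypothetical stand-ins for illustration. The loop near the end is, as I understand it, a simplified version of what `Strategy.load_optimizer_state_dict` does (Lightning additionally moves the optimizer state to the training device).

```python
import io

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical tiny stand-in; in practice this is your (compressed) LightningModule.
    return nn.Linear(4, 2)


# --- 1. Simulate a checkpoint written by an earlier run ---
# Assumption: the dict mirrors Lightning's layout ("state_dict" + list of "optimizer_states").
torch.manual_seed(0)
old_model = make_model()
old_opt = torch.optim.Adam(old_model.parameters(), lr=1e-3)
old_model(torch.randn(8, 4)).sum().backward()
old_opt.step()  # populates Adam's exp_avg / exp_avg_sq buffers

buffer = io.BytesIO()
torch.save(
    {
        "state_dict": old_model.state_dict(),
        "optimizer_states": [old_opt.state_dict()],
    },
    buffer,
)
buffer.seek(0)

# --- 2. New run: keep the new model's weights, restore only the optimizer ---
new_model = make_model()
weights_before = {k: v.clone() for k, v in new_model.state_dict().items()}
new_opt = torch.optim.Adam(new_model.parameters(), lr=1e-3)

ckpt = torch.load(buffer, map_location="cpu")
for opt, opt_state in zip([new_opt], ckpt["optimizer_states"]):
    opt.load_state_dict(opt_state)
# ckpt["state_dict"] is deliberately never loaded, so new_model's weights are untouched.
```

Note that this only makes sense if the compressed model exposes the same parameters, in the same order and with the same shapes, as the model the checkpoint was saved from, because PyTorch matches optimizer state to parameters by position rather than by name.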

Answer selected by shenoynikhil98