How do I load a model checkpointed with ddp_spawn? #5728
laughingrice asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule (Unanswered)
I am using pytorch_lightning.callbacks.ModelCheckpoint(save_weights_only=True) to checkpoint my trained model.
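For reference, my setup is roughly the following (a minimal sketch; the real module and the other Trainer arguments are omitted):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the model weights in the checkpoint, not the full training state
checkpoint_cb = ModelCheckpoint(save_weights_only=True)

# Training on a 4-GPU machine; accelerator is whatever Lightning picks by default
trainer = pl.Trainer(gpus=4, callbacks=[checkpoint_cb])
```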
Setting aside for the moment that my checkpoint files are huge (6.2 GB for a model with 2.7M trainable parameters), it seems that the model is checkpointed in place, i.e. the tensors are saved on whatever device they were trained on.
I am training on a system with 4 GPUs. In the simple case, if I train the model passing accelerator='dp' to the Trainer object, the checkpoint is saved with device='cuda:0'. I can deal with that by passing map_location='cpu' to torch.load; inconvenient, but manageable.
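Concretely, for the dp case the load looks something like this (the path and the model class name are placeholders):

```python
import torch

model = MyModel()  # MyModel is a placeholder for my LightningModule class

# A checkpoint written under accelerator='dp' has its tensors on cuda:0;
# remapping everything to CPU at load time works around that.
ckpt = torch.load("path/to/dp_checkpoint.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
```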
If I don't set accelerator='dp', however, it looks like ddp_spawn is chosen, and I cannot figure out how to reload the checkpointed model: I get AssertionError: Default process group is not initialized, regardless of the map_location value. I found one post discussing how to set up the distributed environment, which I couldn't figure out, but I am really just looking to save the model weights without any of the device/environment inference, so that they can be loaded generically on any device without much headache.
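The equivalent load against a ddp_spawn checkpoint is what fails (same placeholder names as above):

```python
import torch

model = MyModel()  # same placeholder LightningModule as above

# With a checkpoint written under ddp_spawn, this raises
#   AssertionError: Default process group is not initialized
# no matter what map_location I pass.
ckpt = torch.load("path/to/ddp_spawn_checkpoint.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
```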
How do I tell PyTorch Lightning / ModelCheckpoint to transfer the model to the CPU before saving, to make loading easier? (And/or how do I use MLflow to checkpoint the model?)
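Something along the lines of the sketch below is what I'm after, though I don't know whether the on_save_checkpoint hook is the intended mechanism for this (the class and key handling here are just an illustration, not something I have verified):

```python
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    # Hypothetical workaround: move every weight tensor to CPU right
    # before the checkpoint dict is written, so the resulting file can
    # later be loaded on any device without device/environment inference.
    def on_save_checkpoint(self, checkpoint):
        checkpoint["state_dict"] = {
            k: v.cpu() for k, v in checkpoint["state_dict"].items()
        }
```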