🐛 Bug
It looks like the default checkpoint connector doesn't handle DeepSpeed optimizer checkpointing properly. Among other issues, restore_training_state() (in pytorch_lightning==1.3.7.post0) passes a dictionary to DeepSpeed's load_state_dict(), which seems to expect a list.
Reproduction
To reproduce, train any model with DeepSpeed, using one of DeepSpeed's optimizers (I used FusedAdam), and save a checkpoint. Then attempt to resume from that checkpoint with the Trainer's resume_from_checkpoint argument. That should cause a crash.
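For concreteness, here's a minimal sketch of the kind of setup that triggers it for me; the model, the random data, and the checkpoint path are illustrative placeholders rather than my actual code:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from deepspeed.ops.adam import FusedAdam


class LitModel(pl.LightningModule):  # placeholder model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # cast inputs to the (possibly fp16) model dtype
        x, y = x.type_as(self.layer.weight), y.type_as(self.layer.weight)
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # using one of DeepSpeed's own optimizers is part of the repro
        return FusedAdam(self.parameters(), lr=1e-3)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8
)

# First run: train briefly and let the default ModelCheckpoint save a checkpoint.
trainer = pl.Trainer(gpus=1, precision=16, plugins="deepspeed_stage_2", max_epochs=1)
trainer.fit(LitModel(), train_loader)

# Second run: resuming from that checkpoint crashes inside restore_training_state().
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    plugins="deepspeed_stage_2",
    resume_from_checkpoint="path/to/saved.ckpt",  # placeholder; point at the checkpoint from the first run
)
trainer.fit(LitModel(), train_loader)
```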
Here's the trace I get:
Traceback (most recent call last):
File "dilated_resnet_pl.py", line 578, in <module>
trainer.fit(model_module, data_module)
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
self._run(model)
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
self.dispatch()
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
self.accelerator.start_training(self)
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
self._results = trainer.run_stage()
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
return self.run_train()
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 837, in run_train
self._pre_training_routine()
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 830, in _pre_training_routine
self.checkpoint_connector.restore_weights()
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 73, in restore_weights
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 102, in restore
self.restore_training_state(checkpoint, load_optimizer_states)
File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 183, in restore_training_state
optimizer.load_state_dict(opt_state)
File "/home/ga122/code/venv/lib/python3.6/site-packages/deepspeed/runtime/zero/stage2.py", line 1951, in load_state_dict
self.loss_scaler = state_dict_list[0]['loss_scaler']
KeyError: 0
At ZeRO stage 1, the issue can be fixed by simply wrapping opt_state in a list, as follows:
optimizer.load_state_dict([opt_state])
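If patching the installed package isn't convenient, an equivalent way to apply the same list-wrapping (again, this is only a sketch and I'd only expect it to help at stage 1) is to rewrite the checkpoint before resuming; the paths below are placeholders:

```python
import torch

# Hypothetical one-off script: wrap each optimizer state in a list so that
# DeepSpeed's ZeRO stage 1 load_state_dict() receives the structure it expects.
src = "path/to/broken.ckpt"   # placeholder path
dst = "path/to/wrapped.ckpt"  # placeholder path

ckpt = torch.load(src, map_location="cpu")
ckpt["optimizer_states"] = [
    state if isinstance(state, list) else [state]
    for state in ckpt["optimizer_states"]
]
torch.save(ckpt, dst)
# Resuming with resume_from_checkpoint=dst should then get past the KeyError at stage 1.
```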
However, at higher levels of ZeRO optimization, when the optimizer state is partitioned, that doesn't cut it. In that case, the optimizer state seems to be stored differently from how DeepSpeed expects it: in deepspeed/runtime/zero/stage2.py, load_state_dict() iterates over the list it is passed, expecting one state dict per partition, whereas the checkpoint actually contains a single entry holding the state for all partitions (though the lengths don't exactly add up, so I can't really figure out what's going wrong).
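For anyone digging into this, here's roughly how I've been inspecting what Lightning stored versus what stage2.py expects to iterate over (the path is a placeholder):

```python
import torch

ckpt = torch.load("path/to/broken.ckpt", map_location="cpu")  # placeholder path
opt_states = ckpt["optimizer_states"]

# Lightning stores one entry per optimizer...
print(type(opt_states), len(opt_states))
# ...and each entry is a single dict, whereas DeepSpeed's stage2 load_state_dict()
# iterates over its argument expecting one state dict per ZeRO partition.
print(type(opt_states[0]), list(opt_states[0].keys()))
```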
I'm running pytorch-lightning==1.3.3 and deepspeed==0.3.17+c1550b8 (compiled from source), though the issue is present in the current pip version of deepspeed and pytorch-lightning==1.3.7.post0.
#7282 is similar, but doesn't report this particular crash or the fact that the ZeRO stage matters.