DeepSpeed + training checkpointing doesn't work #8092

@gahdritz

🐛 Bug

It looks like the default checkpoint connector doesn't handle DeepSpeed optimizer checkpointing properly. Among other issues, restore_training_state() (in pytorch_lightning==1.3.7.post0) passes DeepSpeed's load_state_dict() a dictionary, whereas it seems to expect a list.
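For concreteness, the KeyError in the trace below falls out of a simple type mismatch. A toy illustration (not the actual Lightning or DeepSpeed code; the only key I'm sure about is loss_scaler, which appears in the trace):

opt_state = {"loss_scaler": 1.0}  # stand-in for the optimizer state dict stored in the checkpoint

# Indexing the dict like a list is effectively what stage2.py line 1951 does:
try:
    opt_state[0]["loss_scaler"]
except KeyError as e:
    print("KeyError:", e)  # KeyError: 0

# Wrapping it in a list is what DeepSpeed appears to expect:
[opt_state][0]["loss_scaler"]  # fine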

Reproduction

To reproduce, train any model with DeepSpeed, using one of DeepSpeed's optimizers (I used FusedAdam), and create a checkpoint. Attempt to resume from that checkpoint with the Trainer's resume_from_checkpoint argument. That should cause a crash.
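For reference, a minimal sketch of what my setup boils down to (not my actual training script; the model, dataloader, and checkpoint path are placeholders):

import torch
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin
from deepspeed.ops.adam import FusedAdam

class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # one of DeepSpeed's own optimizers, as in my setup
        return FusedAdam(self.parameters())

def loader():
    return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=2)

# First run: train briefly so a checkpoint gets written.
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1,
                     plugins=DeepSpeedPlugin(stage=2),
                     default_root_dir="repro")
trainer.fit(BoringModel(), loader())

# Second run: resuming from that checkpoint crashes in restore_training_state().
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=2,
                     plugins=DeepSpeedPlugin(stage=2),
                     resume_from_checkpoint="repro/path/to/last.ckpt")  # placeholder path
trainer.fit(BoringModel(), loader())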

Here's the trace I get:

Traceback (most recent call last):
  File "dilated_resnet_pl.py", line 578, in <module>
    trainer.fit(model_module, data_module)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
    self.dispatch()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
    self.accelerator.start_training(self)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
    return self.run_train()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 837, in run_train
    self._pre_training_routine()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 830, in _pre_training_routine
    self.checkpoint_connector.restore_weights()
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 73, in restore_weights
    self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 102, in restore
    self.restore_training_state(checkpoint, load_optimizer_states)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 183, in restore_training_state
    optimizer.load_state_dict(opt_state)
  File "/home/ga122/code/venv/lib/python3.6/site-packages/deepspeed/runtime/zero/stage2.py", line 1951, in load_state_dict
    self.loss_scaler = state_dict_list[0]['loss_scaler']
KeyError: 0

At ZeRO stage 1, the issue can be fixed by simply wrapping opt_state in a list, as follows:

optimizer.load_state_dict([opt_state])
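In context, that's a one-line local change to checkpoint_connector.py (sketched from the 1.3.x source; the surrounding code is approximate):

# pytorch_lightning/trainer/connectors/checkpoint_connector.py, restore_training_state()
optimizer_states = checkpoint["optimizer_states"]
for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
    # wrapped in a list so DeepSpeed's stage2 load_state_dict() can index it positionally
    optimizer.load_state_dict([opt_state])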

However, at higher ZeRO stages, when the optimizer state is partitioned, that doesn't cut it. In that case, the optimizer state seems to be stored differently from how DeepSpeed expects it: in deepspeed/runtime/zero/stage2.py, load_state_dict() iterates over the list it's passed, expecting one item per partition. The checkpoint actually seems to contain a single item holding the state for all partitions (though the lengths don't exactly add up; I can't really figure out what's going wrong).
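Schematically, what I think is happening at the higher stages (key names other than loss_scaler are illustrative placeholders, not the actual DeepSpeed checkpoint keys; shapes are hypothetical, world size 2):

# What DeepSpeed's stage2 load_state_dict() appears to expect: one state dict per
# rank, each holding only that rank's partition of the flattened optimizer state.
expected = [
    {"loss_scaler": 1.0, "fp32_partition": "<tensors for rank 0's partition>"},
    {"loss_scaler": 1.0, "fp32_partition": "<tensors for rank 1's partition>"},
]

# What the Lightning checkpoint seems to contain instead: a single dict whose
# state spans all partitions.
stored = {"loss_scaler": 1.0, "fp32_partition": "<tensors for all partitions?>"}

# Wrapping `stored` in a list satisfies state_dict_list[0] at stage 1, but once
# the state is actually partitioned, the per-rank tensor sizes no longer line up
# with what load_state_dict() tries to restore.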

I'm running pytorch-lightning==1.3.3 and deepspeed==0.3.17+c1550b8 (compiled from source), though the issue is present in the current pip version of deepspeed and pytorch-lightning==1.3.7.post0.

#7282 is similar, but doesn't report this particular crash, or the fact that the ZeRO stage matters.

Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)
