35 | 35 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
36 | 36 | from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE |
37 | 37 | from pytorch_lightning.utilities.types import LRSchedulerTypeTuple |
38 | | -from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, warning_cache |
| 38 | +from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning |
39 | 39 |
40 | 40 | if _DEEPSPEED_AVAILABLE: |
41 | 41 | import deepspeed |
@@ -671,19 +671,18 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: |
671 | 671 | checkpoint: The checkpoint state dictionary |
672 | 672 | filepath: write-target file's path |
673 | 673 | """ |
674 | | - if self.zero_stage_3 and self._multi_device and self.is_global_zero: |
675 | | - warning_cache.warn( |
676 | | - "When saving the DeepSpeed Stage 3 checkpoint, " |
677 | | - "each worker will save a shard of the checkpoint within a directory. " |
678 | | - "If a single file is required after training, " |
679 | | - "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" |
680 | | - "deepspeed-zero-stage-3-single-file for instructions." |
681 | | - ) |
682 | | - # Use deepspeed's internal checkpointing function to handle partitioned weights across processes |
683 | | - # dump states as a checkpoint dictionary object |
684 | | - _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] |
685 | | - checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} |
686 | | - self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) |
| 674 | + if self.world_size > 1 and self.zero_stage_3: |
| 675 | + if self.save_full_weights: |
| 676 | + # todo: expose this as general function in deepspeed |
| 677 | + state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict() |
| 678 | + if self.is_global_zero: |
| 679 | + # State dict keys will include a reference to the wrapper LightningDeepSpeedModule. |
| 680 | + # Delete the `module.` prefix before saving. |
| 681 | + state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()} |
| 682 | + checkpoint["state_dict"] = state_dict |
| 683 | + return super().save_checkpoint(checkpoint, filepath) |
| 684 | + return |
| 685 | + |
687 | 686 | # Use deepspeed's internal checkpointing function to handle partitioned weights across processes |
688 | 687 | # dump states as a checkpoint dictionary object |
689 | 688 | save_dir = self._filepath_to_dir(filepath) |
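
For context on the new branch above: the comprehension at new line 681 strips the `module.` prefix that the LightningDeepSpeedModule wrapper adds to every parameter key, before the consolidated state dict is handed to the base-class save. A minimal standalone sketch of that renaming, using made-up keys:

    # Hypothetical keys, mimicking a state dict gathered through the LightningDeepSpeedModule wrapper.
    state_dict = {"module.layer.weight": 0.1, "module.layer.bias": 0.2}
    # str.partition("module.")[2] keeps everything after the first occurrence of the prefix,
    # so "module.layer.weight" becomes "layer.weight".
    stripped = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
    print(stripped)  # {'layer.weight': 0.1, 'layer.bias': 0.2}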