
Commit 3e6df2f

justusschock authored and lexierule committed
Remove todo, ensure we only check rank 0 for deepspeed warning (#9311)
1 parent f2c5f5b commit 3e6df2f

3 files changed: +64 -14 lines changed


pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 14 additions & 13 deletions
@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
-from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning
+from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning, warning_cache

 if _DEEPSPEED_AVAILABLE:
     import deepspeed
@@ -671,18 +671,19 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
             checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
         """
-        if self.world_size > 1 and self.zero_stage_3:
-            if self.save_full_weights:
-                # todo: expose this as general function in deepspeed
-                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
-                if self.is_global_zero:
-                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
-                    # Delete `module` prefix before saving.
-                    state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
-                    checkpoint["state_dict"] = state_dict
-                    return super().save_checkpoint(checkpoint, filepath)
-                return
-
+        if self.zero_stage_3 and self._multi_device and self.is_global_zero:
+            warning_cache.warn(
+                "When saving the DeepSpeed Stage 3 checkpoint, "
+                "each worker will save a shard of the checkpoint within a directory. "
+                "If a single file is required after training, "
+                "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
+                "deepspeed-zero-stage-3-single-file for instructions."
+            )
+        # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
+        # dump states as a checkpoint dictionary object
+        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
+        checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
+        self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
         # dump states as a checkpoint dictionary object
         save_dir = self._filepath_to_dir(filepath)
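The new path gates the message on self.is_global_zero and routes it through warning_cache, so it is raised once, on rank 0 only, instead of on every worker each time a checkpoint is saved. A minimal sketch of that pattern, for illustration only (this is not Lightning's actual WarningCache implementation, and the standalone save_checkpoint below is a simplified stand-in for the plugin method):

import warnings


class WarningCache:
    """Illustrative once-per-message warning cache (simplified)."""

    def __init__(self) -> None:
        self._seen = set()

    def warn(self, message: str) -> None:
        # Emit each distinct message at most once per process.
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message, UserWarning)


warning_cache = WarningCache()


def save_checkpoint(zero_stage_3: bool, multi_device: bool, is_global_zero: bool) -> None:
    # Mirrors the gating added in the diff: only rank 0 of a multi-device
    # ZeRO Stage 3 run surfaces the sharded-checkpoint warning.
    if zero_stage_3 and multi_device and is_global_zero:
        warning_cache.warn("each worker will save a shard of the checkpoint within a directory.")

Because the cache keys on the message text, repeated saves within one run surface the warning a single time, which lines up with the new test's assert len(record) == 1 on rank zero.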

tests/plugins/test_deepspeed_plugin.py

Lines changed: 49 additions & 0 deletions
@@ -419,6 +419,55 @@ def test_deepspeed_fp32_works(tmpdir):
     trainer.fit(model)


+@RunIf(min_gpus=2, deepspeed=True, special=True)
+def test_deepspeed_stage_3_save_warning(tmpdir):
+    """Test to ensure that DeepSpeed Stage 3 gives a warning when saving on rank zero."""
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
+    )
+    trainer.fit(model)
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    with pytest.warns(UserWarning) as record:
+        # both ranks need to call save checkpoint
+        trainer.save_checkpoint(checkpoint_path)
+    if trainer.is_global_zero:
+        assert len(record) == 1
+        match = "each worker will save a shard of the checkpoint within a directory."
+        assert match in str(record[0].message)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_single_file(tmpdir):
+    """Test to ensure that DeepSpeed loads from a single file checkpoint."""
+    model = BoringModel()
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.fit(model)
+    trainer.save_checkpoint(checkpoint_path)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
+    )
+    plugin = trainer.training_type_plugin
+    assert isinstance(plugin, DeepSpeedPlugin)
+    assert not plugin.load_full_weights
+    with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."):
+        trainer.test(model, ckpt_path=checkpoint_path)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[DeepSpeedPlugin(stage=3, load_full_weights=True)],
+        gpus=1,
+        fast_dev_run=True,
+        precision=16,
+    )
+    plugin = trainer.training_type_plugin
+    assert isinstance(plugin, DeepSpeedPlugin)
+    assert plugin.load_full_weights
+    trainer.test(model, ckpt_path=checkpoint_path)
+
+
 class ModelParallelClassificationModel(LightningModule):
     def __init__(self, lr: float = 0.01, num_blocks: int = 5):
         super().__init__()
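The second test covers loading a single-file checkpoint into a Stage 3 run via load_full_weights=True. For the opposite direction described by the new warning, collapsing the sharded Stage 3 checkpoint directory into one file after training, DeepSpeed ships a consolidation helper; the sketch below assumes that utility is available under this name (check deepspeed.utils.zero_to_fp32, or the zero_to_fp32.py script DeepSpeed writes into the checkpoint directory, for the exact entry point in your installed version):

# Sketch: consolidate a sharded DeepSpeed ZeRO Stage 3 checkpoint into a single file.
# The import below is assumed to exist in the installed DeepSpeed release.
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# "checkpoint_dir" is the directory written by trainer.save_checkpoint(...) under Stage 3;
# "single_model.pt" is the consolidated output path (both names are placeholders).
convert_zero_checkpoint_to_fp32_state_dict("checkpoint_dir", "single_model.pt")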

tests/trainer/loops/test_training_loop.py

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ def training_epoch_end(self, outputs) -> None:


 def test_batch_loop_releases_loss(tmpdir):
-    """Test that loss/graph is released so that it can be garbage collected before the next training step"""
+    """Test that loss/graph is released so that it can be garbage collected before the next training step."""

     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
