Commit b3275e0

Authored by dimitri-voytan, dvoytan-spark, awaelchli, and Borda
Sharded state dicts save correctly when save_weights_only=True (#19524)
Co-authored-by: Dimitri <[email protected]>
Co-authored-by: awaelchli <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 8549a93 commit b3275e0
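
For context: with `save_weights_only=True`, the checkpoint dict handed to the strategy contains no `optimizer_states` key, and the sharded FSDP save path previously popped that key unconditionally, raising a KeyError. A minimal sketch of the setup that used to fail, assuming a 2-GPU machine (`MyModel` is a hypothetical placeholder, not part of this commit):

    # Sketch of the configuration that triggered the KeyError before this fix.
    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.strategies import FSDPStrategy

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=FSDPStrategy(state_dict_type="sharded"),
        # save_weights_only=True omits "optimizer_states" from the checkpoint dict
        callbacks=[ModelCheckpoint(save_weights_only=True)],
        max_epochs=1,
    )
    # trainer.fit(MyModel())  # before b3275e0, saving a sharded checkpoint here raised KeyError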

File tree

3 files changed: +10 -5 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))
+
 
 -
 
src/lightning/pytorch/strategies/fsdp.py

Lines changed: 2 additions & 1 deletion
@@ -568,7 +568,8 @@ def save_checkpoint(
 
             converted_state = {"model": checkpoint.pop("state_dict")}
             converted_state.update({
-                f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))
+                f"optimizer_{idx}": optim_state
+                for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
             })
 
             _distributed_checkpoint_save(converted_state, path)
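
The fix is a one-liner: `dict.pop` with a default returns `[]` instead of raising when `optimizer_states` is absent (as it is when `save_weights_only=True`), so the comprehension simply contributes no optimizer entries. A standalone sketch of the behavior, with a toy checkpoint dict:

    # Toy reproduction of the fixed code path (checkpoint contents are illustrative).
    checkpoint = {"state_dict": {"weight": 1.0}}  # weights-only: no "optimizer_states" key

    converted_state = {"model": checkpoint.pop("state_dict")}
    converted_state.update({
        f"optimizer_{idx}": optim_state
        for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
    })
    assert list(converted_state) == ["model"]  # no optimizer shards to write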

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 6 additions & 3 deletions
@@ -185,7 +185,8 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.save_checkpoint(model_path.with_name("after-test"))
     trainer.save_checkpoint(model_path, weights_only=True)
 
-    _assert_save_equality(trainer, model_path, cls=model.__class__)
+    if not model_path.is_dir():  # TODO (@awaelchli): Add support for asserting equality of sharded checkpoints
+        _assert_save_equality(trainer, model_path, cls=model.__class__)
 
     with torch.inference_mode():
         # Test entry point
@@ -279,11 +280,13 @@ def training_step(self, batch, batch_idx):
 
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
-def test_fsdp_strategy_checkpoint(tmpdir, precision):
+@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
+def test_fsdp_strategy_checkpoint(state_dict_type, precision, tmpdir):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
     model = TestFSDPModel()
+    strategy = FSDPStrategy(state_dict_type=state_dict_type)
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
+        default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision=precision, max_epochs=1
     )
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
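
Worth noting for the `is_dir()` guard above: with `state_dict_type="sharded"`, FSDP writes the checkpoint as a directory of per-rank shard files, while `"full"` produces a single consolidated file, which is why the single-file equality assertion is skipped for sharded saves. A small illustrative check (the function name is hypothetical, not part of the test suite):

    from pathlib import Path

    def assert_expected_checkpoint_layout(model_path: Path, state_dict_type: str) -> None:
        # Sharded checkpoints land on disk as a directory of shard files;
        # full checkpoints are one consolidated file.
        if state_dict_type == "sharded":
            assert model_path.is_dir()
        else:
            assert model_path.is_file()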
289292
