Commit 97ddcb1

awaelchli authored and lantiga committed

Fix trainer.save_checkpoint after trainer.test with FSDP (#18992)

(cherry picked from commit 3acea8d)

1 parent deddb0a · commit 97ddcb1
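
In short: after `trainer.fit()`, calling `trainer.test()` makes FSDP re-wrap the model, so the optimizers kept by the strategy still reference parameters from the old wrapping, and a later `trainer.save_checkpoint()` could fail. A minimal sketch of the affected sequence (not taken from the commit; `BoringModel` is Lightning's demo module, and a machine with 2 CUDA GPUs is assumed):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_epochs=1)
model = BoringModel()
trainer.fit(model)
trainer.test(model)  # re-wraps the model with FSDP for evaluation

# Before this fix, saving here could fail: the strategy still held optimizers
# whose parameter references pointed into the pre-test wrapping.
trainer.save_checkpoint("after-test.ckpt")
```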

File tree: 4 files changed, +17 -4 lines changed

src/lightning/pytorch/CHANGELOG.md
Lines changed: 0 additions & 1 deletion

@@ -16,7 +16,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
-
 ## [2.1.1] - 2023-11-06
 
 ### Fixed

src/lightning/pytorch/strategies/fsdp.py
Lines changed: 5 additions & 0 deletions

@@ -329,6 +329,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_precision_plugin()
 
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        # If we're setting up for evaluation after fitting, we need to discard the optimizers
+        # since we're rewrapping the model, otherwise optimizer param references are no longer valid
+        # and subsequent checkpoint saving can fail
+        self._reset_optimizers_and_schedulers()
+
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
 

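The stale-reference problem the new comment describes can be illustrated outside of FSDP entirely. A hedged, stand-alone sketch (not from the commit): a PyTorch optimizer holds references to the exact `Parameter` objects it was constructed with, so replacing a module's parameters, as re-wrapping does, leaves the optimizer pointing at objects the model no longer uses.

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The optimizer stores references to the original Parameter objects.
before = {id(p) for group in optimizer.param_groups for p in group["params"]}

# Stand-in for FSDP re-wrapping: the module's weight is replaced by a new
# Parameter object, while the optimizer still references the old one.
model.weight = torch.nn.Parameter(model.weight.detach().clone())
after = {id(p) for p in model.parameters()}

# Only the untouched bias still matches; the weight the optimizer knows about
# is no longer the weight the model (and hence a checkpoint) would contain.
print(len(before & after))  # 1, not 2
```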
src/lightning/pytorch/strategies/strategy.py
Lines changed: 5 additions & 0 deletions

@@ -575,6 +575,11 @@ def on_exception(self, exception: BaseException) -> None:
         """Called when the trainer execution is interrupted by an exception."""
         pass
 
+    def _reset_optimizers_and_schedulers(self) -> None:
+        self._optimizers = []
+        self._lightning_optimizers = []
+        self.lr_scheduler_configs = []
+
     def __getstate__(self) -> Dict:
         # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled
         state = dict(vars(self))  # copy

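As a hedged illustration of how the new base-class helper is meant to be used (the `MyStrategy` subclass below is hypothetical, not part of the commit), a strategy that re-wraps its module can clear the stale optimizer and scheduler references before setting up fresh ones, mirroring what `FSDPStrategy.setup_optimizers` now does:

```python
import lightning.pytorch as pl
from lightning.pytorch.strategies import Strategy


class MyStrategy(Strategy):  # hypothetical subclass, for illustration only
    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        # Drop optimizers/schedulers that still reference the old wrapping...
        self._reset_optimizers_and_schedulers()
        # ...then rebuild them from the LightningModule's configure_optimizers().
        super().setup_optimizers(trainer)
```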
tests/tests_pytorch/strategies/test_fsdp.py
Lines changed: 7 additions & 3 deletions

@@ -173,9 +173,13 @@ def _assert_layer_fsdp_instance(self) -> None:
 
 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)
+    trainer.test(model)
+
     model_path = trainer.strategy.broadcast(model_path)
-    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
+    model_path = Path(model_path if model_path else trainer.checkpoint_callback.last_model_path)
 
+    # Save another checkpoint after testing, without optimizer states
+    trainer.save_checkpoint(model_path.with_name("after-test"))
     trainer.save_checkpoint(model_path, weights_only=True)
 
     _assert_save_equality(trainer, model_path, cls=model.__class__)

@@ -270,13 +274,13 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
 
-@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 def test_fsdp_strategy_checkpoint(tmpdir, precision):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
     model = TestFSDPModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp", precision=precision, max_epochs=1
+        default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
     )
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 

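After the fix, both files written by `_run_multiple_stages` should exist and load cleanly. A hedged way to inspect them (paths assume the test's tmpdir layout; per the comment in the diff, the "after-test" checkpoint carries no optimizer states because the strategy discarded them when test setup re-wrapped the model):

```python
import torch

# torch.load's own weights_only flag (unrelated to Trainer.save_checkpoint's
# argument of the same name) must be False to unpickle a Lightning checkpoint.
after_test = torch.load("after-test", map_location="cpu", weights_only=False)
last = torch.load("last.ckpt", map_location="cpu", weights_only=False)

# Print the top-level keys to see what each save call included.
print(sorted(after_test))
print(sorted(last))
```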