
Commit 6f6c07d

Revert removal of empty-parameters check for configure_optimizers() with FSDP (#18785)
1 parent 20ce3ae commit 6f6c07d

3 files changed, +28 -4 lines changed


src/lightning/pytorch/strategies/fsdp.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -327,7 +327,18 @@ def setup(self, trainer: "pl.Trainer") -> None:
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+
+        invalid_params_error = False
+        try:
+            # In PyTorch < 2.0, or if `use_orig_params=False` the user needs to do access
+            # `self.trainer.model.parameters()` in configure_optimizers()
+            super().setup_optimizers(trainer)
+        except ValueError as ex:
+            if "optimizer got an empty parameter list" not in str(ex):
+                raise
+            invalid_params_error = True
+
+        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
```

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -359,16 +359,22 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
 
 
 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
-def test_invalid_parameters_in_optimizer():
+@pytest.mark.parametrize("use_orig_params", [None, False, True])
+def test_invalid_parameters_in_optimizer(use_orig_params):
+    fsdp_kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
+        fsdp_kwargs = {"use_orig_params": use_orig_params}
+
     trainer = Trainer(
-        strategy="fsdp",
+        strategy=FSDPStrategy(**fsdp_kwargs),
         accelerator="cuda",
         devices=1,
         fast_dev_run=1,
     )
+
     error_context = (
         nullcontext()
-        if _TORCH_GREATER_EQUAL_2_0
+        if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
         else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
     )
 
@@ -385,6 +391,12 @@ def configure_optimizers(self):
             layer = torch.nn.Linear(4, 5)
             return torch.optim.Adam(layer.parameters(), lr=1e-2)
 
+    error_context = (
+        nullcontext()
+        if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
+        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
+    )
+
     model = NoFlatParametersModel()
     with error_context:
         trainer.fit(model)
```
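For reference, the combination the parametrized test expects to pass without error on PyTorch >= 2.0 is the one where FSDP keeps the original parameters. A minimal sketch of that trainer configuration, forwarding the same kwargs the test passes to the strategy (other values illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# With PyTorch >= 2.0 and `use_orig_params=True`, configure_optimizers() may
# reference the module's own parameters directly, so no ValueError is expected.
trainer = Trainer(
    strategy=FSDPStrategy(use_orig_params=True),
    accelerator="cuda",
    devices=1,
    fast_dev_run=1,
)
```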

tests/tests_pytorch/trainer/optimization/test_manual_optimization.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -388,6 +388,7 @@ def on_before_optimizer_step(self, optimizer, *_):
 
 def test_step_with_optimizer_closure(tmpdir):
     """Tests that `step` works with optimizer_closure."""
+    seed_everything(1)
 
     class TestModel(BoringModel):
         _losses = []
```
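The only change in this file is the added seeding call, which makes the losses the test accumulates deterministic across runs. A minimal standalone sketch of that call (import path is the public Lightning API, seed value as in the diff):

```python
from lightning.pytorch import seed_everything

# Fix the relevant RNGs (Python, NumPy, PyTorch) before the model and data
# are created so repeated runs produce the same loss values.
seed_everything(1)
```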
