Commit 524a5b1

awaelchli authored and lantiga committed
Fix initialized weights resetting in Fabric.setup() when using FSDP (#19755)
1 parent 71b13c2 commit 524a5b1

3 files changed: +30 −4 lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [2.2.2] - 2024-04-11
+
+### Fixed
+
+- Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))
+- Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627))
+- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705))
+- Fixed an issue causing weights to be reset in `Fabric.setup()` when using FSDP ([#19755](https://github.com/Lightning-AI/pytorch-lightning/pull/19755))
+
 ## [2.2.1] - 2024-03-04
 
 ### Fixed

src/lightning/fabric/utilities/device_dtype_mixin.py

Lines changed: 2 additions & 4 deletions
@@ -109,14 +109,12 @@ def half(self) -> Self:
 def _update_properties(
     root: torch.nn.Module, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
 ) -> None:
-    def apply_fn(module: Union[_DeviceDtypeModuleMixin, Module]) -> None:
+    for module in root.modules():
         if not isinstance(module, _DeviceDtypeModuleMixin):
-            return
+            continue
         # cannot use `module.to()` because we don't actually want to move the model in case there are multiple
         # devices types (such as partial meta parameters)
         if device is not None:
             module._device = device
         if dtype is not None:
             module._dtype = dtype
-
-    root.apply(apply_fn)
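
The fix replaces the nested `apply_fn` callback with a direct loop over `root.modules()`. Both approaches visit every submodule, but `root.apply(...)` dispatches through the module's own `apply()` method, which `FullyShardedDataParallel` overrides; avoiding that call is what prevents already-initialized weights from being reset during `Fabric.setup()` (per the commit title and the regression test below). A minimal, self-contained sketch of the difference, using a hypothetical `Tracked` toy module rather than FSDP:

```python
from torch import nn


class Tracked(nn.Module):
    """Toy module whose apply() records that it was called (stand-in for an override like FSDP's)."""

    def __init__(self) -> None:
        super().__init__()
        self.inner = nn.Linear(2, 2)
        self.apply_called = False

    def apply(self, fn):
        # Side effects could happen here, as they do in FullyShardedDataParallel.apply().
        self.apply_called = True
        return super().apply(fn)


root = Tracked()

# Old pattern: goes through the (possibly overridden) apply() hook.
root.apply(lambda module: None)
assert root.apply_called

# New pattern: plain iteration visits the same modules without ever calling apply().
root.apply_called = False
visited = [type(m).__name__ for m in root.modules()]
assert not root.apply_called
print(visited)  # ['Tracked', 'Linear']
```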

tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 19 additions & 0 deletions
@@ -667,3 +667,22 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
     model, optimizer = fabric.setup(model, optimizer)
     state = {"model": model, "optimizer": optimizer, "steps": 1}
     fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True)
+def test_no_call_to_apply(monkeypatch):
+    """Regression test to ensure we're not calling `FSDP.apply()` indirectly (see #19755)."""
+    monkeypatch.setattr(torch.distributed.fsdp.FullyShardedDataParallel, "apply", Mock())
+
+    fabric = Fabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        devices=2,
+    )
+    fabric.launch()
+
+    for setup_method in ("setup", "setup_module"):
+        model = BoringModel()
+        setup = getattr(fabric, setup_method)
+        model = setup(model)
+        model._forward_module.apply.assert_not_called()
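
For context, here is a sketch of the user-facing scenario this commit protects, assuming a toy `Linear` model and a two-GPU run; the script is illustrative and not part of the commit. Weights customized before `Fabric.setup()` are expected to survive FSDP wrapping.

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy


def main() -> None:
    fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy())
    fabric.launch()

    model = torch.nn.Linear(4, 4)
    # Custom initialization done *before* setup(); with this fix, wrapping the
    # model with FSDP inside setup() no longer resets these values.
    torch.nn.init.ones_(model.weight)

    model = fabric.setup(model)
    # ... training loop would go here ...


if __name__ == "__main__":
    main()
```

Like the regression test above, running this sketch requires at least two CUDA devices.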
