Commit 23495cc

awaelchli authored and lantiga committed
Fix state dict loading for ddp/dp in Fabric (#17997)
* fix state dict loading for ddp/dp
* test
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* changelog
* update test
* move params to same device before equality test
* test strategy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit b14ddd9)
1 parent 3ba4ae7 commit 23495cc
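For context, the bug: a state dict saved before Fabric.setup() has plain keys ("weight", "bias"), while a DDP/DP-wrapped module expects them prefixed with "module.". Before this commit, Fabric.load() called load_state_dict() directly on the wrapper, so the keys no longer matched. A minimal sketch of the failing pattern (the two-process CPU launch and checkpoint name are illustrative assumptions, not taken from the diff):

import torch
from lightning.fabric import Fabric


def run(fabric):
    model = torch.nn.Linear(2, 2)
    # Saved before setup: plain keys such as "weight" and "bias".
    fabric.save("model.ckpt", {"model": model})

    wrapped = fabric.setup(model)  # DDP nests the model under `module.`
    # Previously this raised missing/unexpected key errors under strict=True;
    # with this commit the strategy unwraps the module before loading.
    fabric.load("model.ckpt", {"model": wrapped})


fabric = Fabric(accelerator="cpu", strategy="ddp_spawn", devices=2)
fabric.launch(run)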

File tree: 7 files changed, +108 −5 lines

src/lightning/fabric/CHANGELOG.md (3 additions, 0 deletions)

@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times ([#17960](https://github.com/Lightning-AI/lightning/pull/17960))


+- Fixed loading model state when `Fabric.load()` is called after `Fabric.setup()` ([#17997](https://github.com/Lightning-AI/lightning/pull/17997))
+
+
 ## [2.0.3] - 2023-06-07

 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))

src/lightning/fabric/strategies/ddp.py (7 additions, 0 deletions)

@@ -160,6 +160,13 @@ def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
             module = module.module
         return super().get_module_state_dict(module)

+    def load_module_state_dict(
+        self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True
+    ) -> None:
+        if isinstance(module, DistributedDataParallel):
+            module = module.module
+        super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict)
+
     @classmethod
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         entries = (
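The unwrap above exists because wrapper modules nest the original model under a `module` attribute, which prefixes every state dict key. A quick illustration, using DataParallel as a stand-in since a real DistributedDataParallel needs an initialized process group:

import torch

model = torch.nn.Linear(2, 2)
wrapped = torch.nn.DataParallel(model)  # same nesting pattern as DDP

print(list(model.state_dict()))    # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

# A plain state dict fails against the wrapper under strict=True but matches
# the inner module, which is exactly what the new hook targets:
wrapped.module.load_state_dict(model.state_dict(), strict=True)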

src/lightning/fabric/strategies/dp.py (7 additions, 0 deletions)

@@ -89,6 +89,13 @@ def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
             module = module.module
         return super().get_module_state_dict(module)

+    def load_module_state_dict(
+        self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True
+    ) -> None:
+        if isinstance(module, DataParallel):
+            module = module.module
+        super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict)
+
     @classmethod
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("dp", cls, description=cls.__class__.__name__)

src/lightning/fabric/strategies/strategy.py (7 additions, 2 deletions)

@@ -234,6 +234,12 @@ def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
         """Returns model state."""
         return module.state_dict()

+    def load_module_state_dict(
+        self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True
+    ) -> None:
+        """Loads the given state into the model."""
+        module.load_state_dict(state_dict, strict=strict)
+
     def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.

@@ -279,8 +285,7 @@ def load_checkpoint(
                 continue
             if isinstance(obj, _Stateful):
                 if isinstance(obj, Module):
-                    # TODO(fabric): Make strict loading configurable
-                    obj.load_state_dict(checkpoint.pop(name), strict=True)
+                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=True)
                 else:
                     obj.load_state_dict(checkpoint.pop(name))
             else:
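With the hook on the base class, load_checkpoint() now routes all module loading through load_module_state_dict(), and wrapper-aware subclasses only override the unwrapping step. A small sketch of the base behaviour (mirroring the tests below, which also use SingleDeviceStrategy):

import torch
from lightning.fabric.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy()
src, dst = torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)

# The base implementations delegate straight to the module's own methods.
state = strategy.get_module_state_dict(src)
strategy.load_module_state_dict(dst, state)

assert all(torch.equal(a, b) for a, b in zip(src.parameters(), dst.parameters()))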

tests/tests_fabric/strategies/test_ddp.py (10 additions, 3 deletions)

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from copy import deepcopy
 from unittest import mock
 from unittest.mock import MagicMock, Mock

@@ -83,7 +84,7 @@ def test_ddp_extra_kwargs(ddp_mock):


 def test_ddp_module_state_dict():
-    """Test that the module state dict gets retrieved without the prefixed wrapper keys from DDP."""
+    """Test that the module state dict can be retrieved and loaded without the prefixed wrapper keys from DDP."""

     class DistributedDataParallelMock(MagicMock):
         def __instancecheck__(self, instance):

@@ -94,12 +95,18 @@ def __instancecheck__(self, instance):

     # Without DDP applied (no setup call)
     original_module = torch.nn.Linear(2, 3)
-    assert strategy.get_module_state_dict(original_module).keys() == original_module.state_dict().keys()
+    original_state_dict = deepcopy(original_module.state_dict())
+    retrieved_state_dict = strategy.get_module_state_dict(original_module)
+    assert retrieved_state_dict.keys() == original_state_dict.keys()
+    strategy.load_module_state_dict(original_module, retrieved_state_dict)

     # With DDP applied (setup called)
     with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
         wrapped_module = strategy.setup_module(original_module)
-        assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()
+        retrieved_state_dict = strategy.get_module_state_dict(wrapped_module)
+        assert retrieved_state_dict.keys() == original_state_dict.keys()
+        strategy.load_module_state_dict(wrapped_module, retrieved_state_dict)
+        strategy.load_module_state_dict(wrapped_module, original_state_dict)


 @pytest.mark.parametrize(
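Note the MagicMock subclass overriding __instancecheck__: it lets isinstance(wrapped_module, DistributedDataParallel) succeed under the patch, so the unwrap branches in DDPStrategy are exercised without initializing a real process group.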
New file (63 additions, 0 deletions)

@@ -0,0 +1,63 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+
+import pytest
+import torch
+
+from lightning.fabric import Fabric
+from tests_fabric.helpers.runif import RunIf
+
+
+@pytest.mark.parametrize(
+    "accelerator",
+    [
+        "cpu",
+        pytest.param("cuda", marks=RunIf(min_cuda_gpus=2)),
+    ],
+)
+def test_ddp_save_load(accelerator, tmp_path):
+    """Test that DDP model checkpoints can be saved and loaded successfully."""
+    fabric = Fabric(devices=2, accelerator=accelerator, strategy="ddp_spawn")
+    fabric.launch(_run_ddp_save_load, tmp_path)
+
+
+def _run_ddp_save_load(fabric, tmp_path):
+    fabric.seed_everything(0)
+
+    tmp_path = fabric.broadcast(tmp_path)
+
+    model = torch.nn.Linear(2, 2)
+    params_before = deepcopy(list(model.parameters()))
+
+    # Save
+    fabric.save(tmp_path / "saved_before_setup.ckpt", {"model": model})
+    wrapped_model = fabric.setup(model)
+    fabric.save(tmp_path / "saved_after_setup.ckpt", {"model": wrapped_model})
+
+    def assert_params_equal(params0, params1):
+        assert all(torch.equal(p0, p1.to(p0.device)) for p0, p1 in zip(params0, params1))
+
+    # Load
+    model = torch.nn.Linear(2, 2)
+    fabric.load(tmp_path / "saved_before_setup.ckpt", {"model": model})
+    assert_params_equal(params_before, model.parameters())
+    fabric.load(tmp_path / "saved_after_setup.ckpt", {"model": model})
+    assert_params_equal(params_before, model.parameters())
+
+    wrapped_model = fabric.setup(model)
+    fabric.load(tmp_path / "saved_before_setup.ckpt", {"model": wrapped_model})
+    assert_params_equal(params_before, wrapped_model.parameters())
+    fabric.load(tmp_path / "saved_after_setup.ckpt", {"model": wrapped_model})
+    assert_params_equal(params_before, wrapped_model.parameters())
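The p1.to(p0.device) in assert_params_equal corresponds to the "move params to same device before equality test" item in the commit message: on the CUDA variant, fabric.setup() moves the parameters to the GPU while params_before stays on the CPU.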

tests/tests_fabric/strategies/test_strategy.py (11 additions, 0 deletions)

@@ -66,6 +66,17 @@ def test_save_checkpoint_convert_stateful_objects(tmp_path):
     assert save_checkpoint_mock.call_args[1]["checkpoint"]["anything"] == expected["anything"]


+def test_load_module_state_dict():
+    """Test that `Strategy.load_module_state_dict()` calls `.load_state_dict()` on the module."""
+    strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
+    module = Mock()
+    state_dict = Mock()
+    strategy.load_module_state_dict(module, state_dict)
+    module.load_state_dict.assert_called_with(state_dict, strict=True)
+    strategy.load_module_state_dict(module, state_dict, strict=False)
+    module.load_state_dict.assert_called_with(state_dict, strict=False)
+
+
 def test_load_checkpoint_out_of_place(tmp_path):
     """Test that one can load the full checkpoint into memory just like `torch.load()`."""
     strategy = SingleDeviceStrategy()  # surrogate class to test implementation in base class
