
Commit 9a97fcc

speediedan authored and committed
Generalize Optimizer validation to accommodate both FSDP 1.x and 2.x (#16733)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6336d6a commit 9a97fcc

File tree

4 files changed: +72 / -18 lines changed


src/lightning/fabric/CHANGELOG.md

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,12 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
-## [2.0.1] - 2023-03-21
+
+## [2.0.1] - 2023-03-30
+
+### Changed
+
+- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))
 
 
 ## [2.0.0] - 2023-03-15

src/lightning/fabric/strategies/fsdp.py

Lines changed: 6 additions & 2 deletions
@@ -348,6 +348,10 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO
 
 
 def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
-    from torch.distributed.fsdp import FlatParameter
+    _FSDP_FLATTENED = "_fsdp_flattened"
+    if _TORCH_GREATER_EQUAL_1_13:
+        return any(getattr(param, _FSDP_FLATTENED, False) for param in optimizer.param_groups[0]["params"])
+    else:
+        from torch.distributed.fsdp import FlatParameter
 
-    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
+        return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
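
Context for the change above: before torch 1.13, FSDP replaced module parameters with `FlatParameter` instances, so an `isinstance` check was enough to tell whether an optimizer had been created after wrapping. From torch 1.13 on, and in particular with `use_orig_params=True` in FSDP 2.x, the managed parameters keep their original class and are instead tagged with the `_fsdp_flattened` attribute, which is what the generalized check looks for. The following is a minimal, self-contained sketch of how such a helper is typically used to validate a user-supplied optimizer; `_optimizer_references_fsdp_params` and `validate_optimizer` are illustrative names, not Lightning's actual API:

```python
from torch.optim import Optimizer


def _optimizer_references_fsdp_params(optimizer: Optimizer) -> bool:
    # Sketch of the generalized check: FSDP (torch >= 1.13) tags the parameters it
    # manages with the `_fsdp_flattened` attribute, regardless of whether they are
    # `FlatParameter` instances (FSDP 1.x) or original parameters (`use_orig_params=True`).
    return any(getattr(p, "_fsdp_flattened", False) for p in optimizer.param_groups[0]["params"])


def validate_optimizer(optimizer: Optimizer) -> None:
    # Illustrative validation: if none of the optimizer's parameters are FSDP-managed,
    # the optimizer was most likely created before the model was wrapped.
    if not _optimizer_references_fsdp_params(optimizer):
        raise ValueError(
            "The optimizer does not reference any FSDP-managed parameters. "
            "Create the optimizer after setting up / wrapping the model."
        )
```

Because `getattr(..., False)` falls back to `False`, the same check simply reports `False` for optimizers built over plain, non-FSDP parameters.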

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Changed
 
 - Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))
+- Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))
 
 ### Fixed
 

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 59 additions & 15 deletions
@@ -1,5 +1,6 @@
 import os
-from typing import Any, Dict, Optional
+from functools import partial
+from typing import Any, Callable, Dict, Optional
 from unittest import mock
 from unittest.mock import ANY, Mock
 
@@ -18,7 +19,14 @@
 
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
-    from torch.distributed.fsdp.wrap import wrap
+    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, wrap
+else:
+    size_based_auto_wrap_policy = object
+
+if _TORCH_GREATER_EQUAL_2_0:
+    from torch.distributed.fsdp.wrap import _FSDPPolicy
+else:
+    _FSDPPolicy = object
 
 
 class TestFSDPModel(BoringModel):
@@ -117,17 +125,18 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
 
     _assert_save_equality(trainer, model_path, cls=model.__class__)
 
-    # Test entry point
-    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
+    with torch.inference_mode():
+        # Test entry point
+        trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
 
-    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
-    trainer.test(ckpt_path=model_path)
+        # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
+        trainer.test(ckpt_path=model_path)
 
-    # Predict entry point
-    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`
+        # Predict entry point
+        trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`
 
-    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
-    trainer.predict(ckpt_path=model_path)
+        # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
+        trainer.predict(ckpt_path=model_path)
 
 
 def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
@@ -200,6 +209,20 @@ def test_fsdp_strategy_checkpoint(tmpdir, precision):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
 
+class CustomWrapPolicy(_FSDPPolicy):
+    """This is a wrapper around :func:`_module_wrap_policy`."""
+
+    def __init__(self, min_num_params: int):
+        self._policy: Callable = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
+
+    @property
+    def policy(self):
+        return self._policy
+
+
+custom_fsdp_policy = CustomWrapPolicy(min_num_params=2)
+
+
 if _TORCH_GREATER_EQUAL_2_0:
 
     def custom_auto_wrap_policy(
@@ -221,19 +244,40 @@ def custom_auto_wrap_policy(
 
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 @pytest.mark.parametrize(
-    "model, strategy",
+    "model, strategy, strategy_cfg",
     [
-        (TestFSDPModel(), "fsdp"),
-        (TestFSDPModelAutoWrapped(), FSDPStrategy),
+        pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_auto_wrap_policy},
+            marks=RunIf(max_torch="2.0.0"),
+            id="autowrap_1x",
+        ),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_auto_wrap_policy},
+            marks=RunIf(min_torch="2.0.0"),
+            id="autowrap_2x",
+        ),
+        pytest.param(
+            TestFSDPModelAutoWrapped(),
+            FSDPStrategy,
+            {"auto_wrap_policy": custom_fsdp_policy, "use_orig_params": True},
+            marks=RunIf(min_torch="2.0.0"),
+            id="autowrap_use_orig_params",
+        ),
     ],
 )
-def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy):
+def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
    """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
 
     ck = ModelCheckpoint(save_last=True)
 
+    strategy_cfg = strategy_cfg or {}
     if not isinstance(strategy, str):
-        strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy)
+        strategy = strategy(**strategy_cfg)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
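
To illustrate the new `autowrap_use_orig_params` case outside the test harness, here is a hedged sketch of wiring a `_FSDPPolicy` wrapper around `size_based_auto_wrap_policy` into `FSDPStrategy` with `use_orig_params=True` (torch >= 2.0 assumed; the `min_num_params=2` threshold mirrors the test, and the `demo_strategy`/`Trainer` wiring is illustrative rather than taken from the test file):

```python
from functools import partial

from torch.distributed.fsdp.wrap import _FSDPPolicy, size_based_auto_wrap_policy

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy


class CustomWrapPolicy(_FSDPPolicy):
    """Exposes `size_based_auto_wrap_policy` through the torch 2.0 `_FSDPPolicy` interface."""

    def __init__(self, min_num_params: int):
        self._policy = partial(size_based_auto_wrap_policy, min_num_params=min_num_params)

    @property
    def policy(self):
        return self._policy


# With `use_orig_params=True`, FSDP keeps the original parameter objects and tags them
# with `_fsdp_flattened`, which the generalized optimizer check recognizes.
demo_strategy = FSDPStrategy(auto_wrap_policy=CustomWrapPolicy(min_num_params=2), use_orig_params=True)
trainer = Trainer(accelerator="cuda", devices=2, strategy=demo_strategy, max_epochs=1)
# trainer.fit(model)  # `model` would be a LightningModule that builds its optimizer in
#                     # `configure_optimizers`, i.e. after FSDP wrapping has taken place.
```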
